diff --git a/internal/experiment/dnsping/dnsping_test.go b/internal/experiment/dnsping/dnsping_test.go index 0ecea61157..f24326b04f 100644 --- a/internal/experiment/dnsping/dnsping_test.go +++ b/internal/experiment/dnsping/dnsping_test.go @@ -99,7 +99,7 @@ func TestMeasurer_run(t *testing.T) { t.Run("with netem: without DPI: expect success", func(t *testing.T) { // create a new test environment - env := netemx.MustNewQAEnv(netemx.QAEnvOptionNetStack("8.8.8.8", &netemx.UDPResolverFactory{})) + env := netemx.MustNewQAEnv(netemx.QAEnvOptionNetStack("8.8.8.8", &netemx.DNSOverUDPServerFactory{})) defer env.Close() // we use the same configuration for all resolvers @@ -146,7 +146,7 @@ func TestMeasurer_run(t *testing.T) { t.Run("with netem: with DNS spoofing: expect to see delayed responses", func(t *testing.T) { // create a new test environment - env := netemx.MustNewQAEnv(netemx.QAEnvOptionNetStack("8.8.8.8", &netemx.UDPResolverFactory{})) + env := netemx.MustNewQAEnv(netemx.QAEnvOptionNetStack("8.8.8.8", &netemx.DNSOverUDPServerFactory{})) defer env.Close() // we use the same configuration for all resolvers diff --git a/internal/netemx/dnsoverhttps.go b/internal/netemx/dnsoverhttps.go index 70ace0da76..4016105b86 100644 --- a/internal/netemx/dnsoverhttps.go +++ b/internal/netemx/dnsoverhttps.go @@ -14,5 +14,7 @@ var _ HTTPHandlerFactory = &DNSOverHTTPSHandlerFactory{} // NewHandler implements QAEnvHTTPHandlerFactory. func (f *DNSOverHTTPSHandlerFactory) NewHandler(env NetStackServerFactoryEnv, stack *netem.UNetStack) http.Handler { - return &testingx.DNSOverHTTPSHandler{Config: env.OtherResolversConfig()} + return &testingx.DNSOverHTTPSHandler{ + RoundTripper: testingx.NewDNSRoundTripperWithDNSConfig(env.OtherResolversConfig()), + } } diff --git a/internal/netemx/udpresolver.go b/internal/netemx/dnsoverudp.go similarity index 59% rename from internal/netemx/udpresolver.go rename to internal/netemx/dnsoverudp.go index 14d44acfab..2f23647d10 100644 --- a/internal/netemx/udpresolver.go +++ b/internal/netemx/dnsoverudp.go @@ -11,37 +11,37 @@ import ( "github.com/ooni/probe-cli/v3/internal/runtimex" ) -// UDPResolverFactory implements [NetStackServerFactory] for DNS-over-UDP servers. +// DNSOverUDPServerFactory implements [NetStackServerFactory] for DNS-over-UDP servers. // // When this factory constructs a [NetStackServer], it will use: // // 1. the [NetStackServerFactoryEnv.OtherResolversConfig] as DNS configuration; // -// 2. the [NetStackServerFactoryEnv.Logger] as the logger. +// 2. the [NetStackServerFactoryEnv.Logger] as logger. // // Use this factory along with [QAEnvOptionNetStack] to create DNS-over-UDP servers. -type UDPResolverFactory struct{} +type DNSOverUDPServerFactory struct{} -var _ NetStackServerFactory = &UDPResolverFactory{} +var _ NetStackServerFactory = &DNSOverUDPServerFactory{} // MustNewServer implements NetStackServerFactory. -func (f *UDPResolverFactory) MustNewServer(env NetStackServerFactoryEnv, stack *netem.UNetStack) NetStackServer { - return udpResolverMustNewServer(env.OtherResolversConfig(), env.Logger(), stack) +func (f *DNSOverUDPServerFactory) MustNewServer(env NetStackServerFactoryEnv, stack *netem.UNetStack) NetStackServer { + return dnsOverUDPResolverMustNewServer(env.OtherResolversConfig(), env.Logger(), stack) } -type udpResolverFactoryForGetaddrinfo struct{} +type dnsOverUDPServerFactoryForGetaddrinfo struct{} -var _ NetStackServerFactory = &udpResolverFactoryForGetaddrinfo{} +var _ NetStackServerFactory = &dnsOverUDPServerFactoryForGetaddrinfo{} // MustNewServer implements NetStackServerFactory. -func (f *udpResolverFactoryForGetaddrinfo) MustNewServer(env NetStackServerFactoryEnv, stack *netem.UNetStack) NetStackServer { - return udpResolverMustNewServer(env.ISPResolverConfig(), env.Logger(), stack) +func (f *dnsOverUDPServerFactoryForGetaddrinfo) MustNewServer(env NetStackServerFactoryEnv, stack *netem.UNetStack) NetStackServer { + return dnsOverUDPResolverMustNewServer(env.ISPResolverConfig(), env.Logger(), stack) } -// udpResolverMustNewServer is an internal factory for creating a [NetStackServer] that +// dnsOverUDPResolverMustNewServer is an internal factory for creating a [NetStackServer] that // runs a DNS-over-UDP server using the configured logger, DNS config, and stack. -func udpResolverMustNewServer(config *netem.DNSConfig, logger model.Logger, stack *netem.UNetStack) NetStackServer { - return &udpResolver{ +func dnsOverUDPResolverMustNewServer(config *netem.DNSConfig, logger model.Logger, stack *netem.UNetStack) NetStackServer { + return &dnsOverUDPResolver{ closers: []io.Closer{}, config: config, logger: logger, @@ -50,7 +50,7 @@ func udpResolverMustNewServer(config *netem.DNSConfig, logger model.Logger, stac } } -type udpResolver struct { +type dnsOverUDPResolver struct { closers []io.Closer config *netem.DNSConfig logger model.Logger @@ -59,7 +59,7 @@ type udpResolver struct { } // Close implements NetStackServer. -func (srv *udpResolver) Close() error { +func (srv *dnsOverUDPResolver) Close() error { // make the method locked as requested by the documentation defer srv.mu.Unlock() srv.mu.Lock() @@ -75,7 +75,7 @@ func (srv *udpResolver) Close() error { } // MustStart implements NetStackServer. -func (srv *udpResolver) MustStart() { +func (srv *dnsOverUDPResolver) MustStart() { // make the method locked as requested by the documentation defer srv.mu.Unlock() srv.mu.Lock() diff --git a/internal/netemx/udpresolver_test.go b/internal/netemx/dnsoverudp_test.go similarity index 84% rename from internal/netemx/udpresolver_test.go rename to internal/netemx/dnsoverudp_test.go index 9f3b2321c6..6d0499cef8 100644 --- a/internal/netemx/udpresolver_test.go +++ b/internal/netemx/dnsoverudp_test.go @@ -10,9 +10,9 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) -func TestUDPResolverFactory(t *testing.T) { +func TestDNSOverUDPServerFactory(t *testing.T) { env := MustNewQAEnv( - QAEnvOptionNetStack(AddressDNSGoogle8844, &UDPResolverFactory{}), + QAEnvOptionNetStack(AddressDNSGoogle8844, &DNSOverUDPServerFactory{}), ) defer env.Close() diff --git a/internal/netemx/example_test.go b/internal/netemx/example_test.go index 8b3c9a30df..76684d10ba 100644 --- a/internal/netemx/example_test.go +++ b/internal/netemx/example_test.go @@ -20,8 +20,8 @@ import ( // to use this QA environment in all the examples for this package. func exampleNewEnvironment() *netemx.QAEnv { return netemx.MustNewQAEnv( - netemx.QAEnvOptionNetStack("8.8.4.4", &netemx.UDPResolverFactory{}), - netemx.QAEnvOptionNetStack("9.9.9.9", &netemx.UDPResolverFactory{}), + netemx.QAEnvOptionNetStack("8.8.4.4", &netemx.DNSOverUDPServerFactory{}), + netemx.QAEnvOptionNetStack("9.9.9.9", &netemx.DNSOverUDPServerFactory{}), netemx.QAEnvOptionClientAddress(netemx.DefaultClientAddress), netemx.QAEnvOptionHTTPServer( netemx.AddressWwwExampleCom, netemx.ExampleWebPageHandlerFactory()), diff --git a/internal/netemx/qaenv.go b/internal/netemx/qaenv.go index 02e899ee4d..0e03536516 100644 --- a/internal/netemx/qaenv.go +++ b/internal/netemx/qaenv.go @@ -182,10 +182,10 @@ func MustNewQAEnv(options ...QAEnvOption) *QAEnv { } // make sure we're going to create the ISP's DNS resolver. - qaEnvOptionNetStack(config.ispResolver, &udpResolverFactoryForGetaddrinfo{})(config) + qaEnvOptionNetStack(config.ispResolver, &dnsOverUDPServerFactoryForGetaddrinfo{})(config) // make sure we're going to create the root DNS resolver. - qaEnvOptionNetStack(config.rootResolver, &UDPResolverFactory{})(config) + qaEnvOptionNetStack(config.rootResolver, &DNSOverUDPServerFactory{})(config) // use a prefix logger for the QA env prefixLogger := &logx.PrefixLogger{ diff --git a/internal/netemx/scenario.go b/internal/netemx/scenario.go index 756a12f5fd..364a7f65ac 100644 --- a/internal/netemx/scenario.go +++ b/internal/netemx/scenario.go @@ -125,7 +125,7 @@ func MustNewScenario(config []*ScenarioDomainAddresses) *QAEnv { for _, addr := range sad.Addresses { opts = append(opts, QAEnvOptionNetStack( addr, - &UDPResolverFactory{}, + &DNSOverUDPServerFactory{}, &HTTPSecureServerFactory{ Factory: &DNSOverHTTPSHandlerFactory{}, Ports: []int{443}, diff --git a/internal/netxlite/dnsoverudp_test.go b/internal/netxlite/dnsoverudp_test.go index 67e7ff73a8..7ab12d50fa 100644 --- a/internal/netxlite/dnsoverudp_test.go +++ b/internal/netxlite/dnsoverudp_test.go @@ -12,9 +12,9 @@ import ( "github.com/apex/log" "github.com/google/go-cmp/cmp" "github.com/miekg/dns" + "github.com/ooni/netem" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" - "github.com/ooni/probe-cli/v3/internal/netxlite/filtering" "github.com/ooni/probe-cli/v3/internal/testingx" ) @@ -252,18 +252,14 @@ func TestDNSOverUDPTransport(t *testing.T) { }) t.Run("using a real server", func(t *testing.T) { - srvr := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionCache - }, - Cache: map[string][]string{ - "dns.google.": {"8.8.8.8"}, - }, - } - listener, err := srvr.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) - } + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + } + dnsConfig := netem.NewDNSConfig() + dnsConfig.AddRecord("dns.google", "", "8.8.8.8") + dnsRtx := testingx.NewDNSRoundTripperWithDNSConfig(dnsConfig) + listener := testingx.MustNewDNSOverUDPListener(udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsRtx) defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) txp := NewUnwrappedDNSOverUDPTransport(dialer, listener.LocalAddr().String()) @@ -284,20 +280,20 @@ func TestDNSOverUDPTransport(t *testing.T) { }) t.Run("recording delayed DNS responses", func(t *testing.T) { + dnsConfigGood := netem.NewDNSConfig() + dnsConfigGood.AddRecord("dns.google", "", "8.8.8.8") + + dnsConfigBogus := netem.NewDNSConfig() + dnsConfigBogus.AddRecord("dns.google", "", "127.0.0.1") + t.Run("without any context-injected traces", func(t *testing.T) { - srvr := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionLocalHostPlusCache - }, - Cache: map[string][]string{ - "dns.google.": {"8.8.8.8"}, - }, - } - listener, err := srvr.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, } - defer listener.Close() + listener := testingx.MustNewDNSSimulateGWFListener( + udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsConfigBogus, + dnsConfigGood, testingx.DNSNumBogusResponses(1)) dialer := NewDialerWithoutResolver(model.DiscardLogger) expectedAddress := listener.LocalAddr().String() txp := NewUnwrappedDNSOverUDPTransport(dialer, expectedAddress) @@ -325,18 +321,13 @@ func TestDNSOverUDPTransport(t *testing.T) { goodLookupAddrs bool goodError bool ) - srvr := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionLocalHostPlusCache - }, - Cache: map[string][]string{ - "dns.google.": {"8.8.8.8"}, - }, - } - listener, err := srvr.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, } + listener := testingx.MustNewDNSSimulateGWFListener( + udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsConfigBogus, + dnsConfigGood, testingx.DNSNumBogusResponses(1)) defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) expectedAddress := listener.LocalAddr().String() @@ -420,19 +411,14 @@ func TestDNSOverUDPTransport(t *testing.T) { goodLookupAddrs bool goodError bool ) - srvr := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionLocalHostPlusCache - }, - Cache: map[string][]string{ - // Note: the cache here is nonexistent so we should - // get a "no such host" error from the server. - }, - } - listener, err := srvr.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) - } + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + } + // Note: the config here is empty so we should get a "no such host" error from the server. + listener := testingx.MustNewDNSSimulateGWFListener( + udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsConfigBogus, + netem.NewDNSConfig(), testingx.DNSNumBogusResponses(1)) defer listener.Close() dialer := NewDialerWithoutResolver(model.DiscardLogger) expectedAddress := listener.LocalAddr().String() diff --git a/internal/netxlite/filtering/dns.go b/internal/netxlite/filtering/dns.go index 26c1e87736..b0248f2074 100644 --- a/internal/netxlite/filtering/dns.go +++ b/internal/netxlite/filtering/dns.go @@ -1,165 +1,13 @@ package filtering import ( - "io" "net" "strings" - "time" "github.com/miekg/dns" "github.com/ooni/probe-cli/v3/internal/runtimex" ) -// DNSAction is a DNS filtering action that a DNSServer should take. -type DNSAction string - -const ( - // DNSActionNXDOMAIN replies with NXDOMAIN. - DNSActionNXDOMAIN = DNSAction("nxdomain") - - // DNSActionRefused replies with Refused. - DNSActionRefused = DNSAction("refused") - - // DNSActionLocalHost replies with `127.0.0.1` and `::1`. - DNSActionLocalHost = DNSAction("localhost") - - // DNSActionNoAnswer returns an empty reply. - DNSActionNoAnswer = DNSAction("no-answer") - - // DNSActionTimeout never replies to the query. - DNSActionTimeout = DNSAction("timeout") - - // DNSActionCache causes the server to check the cache. If there - // are entries, they are returned. Otherwise, NXDOMAIN is returned. - DNSActionCache = DNSAction("cache") - - // DNSActionLocalHostPlusCache combines the LocalHost and - // Cache actions returning first a localhost response followed - // by a subsequent response obtained using the cache. - DNSActionLocalHostPlusCache = DNSAction("localhost+cache") -) - -// DNSServer is a DNS server implementing filtering policies. -type DNSServer struct { - // Cache is the OPTIONAL DNS cache. Note that the keys of the map - // must be FQDNs (i.e., including the final `.`). - Cache map[string][]string - - // OnQuery is the MANDATORY hook called whenever we - // receive a query for the given domain. - OnQuery func(domain string) DNSAction - - // onTimeout is the OPTIONAL channel where we emit a true - // value each time there's a timeout. If you set this value - // to a non-nil channel, then you MUST drain the channel - // for each expected timeout. Otherwise, the code will just - // ignore this field and nothing will be emitted. - onTimeout chan bool -} - -// DNSListener is the interface returned by DNSServer.Start. -type DNSListener interface { - io.Closer - LocalAddr() net.Addr -} - -// Start starts this server. -func (p *DNSServer) Start(address string) (DNSListener, error) { - pconn, _, err := p.start(address) - return pconn, err -} - -func (p *DNSServer) start(address string) (DNSListener, <-chan interface{}, error) { - pconn, err := net.ListenPacket("udp", address) - if err != nil { - return nil, nil, err - } - done := make(chan interface{}) - go p.mainloop(pconn, done) - return pconn, done, nil -} - -func (p *DNSServer) mainloop(pconn net.PacketConn, done chan<- interface{}) { - defer close(done) - for p.oneloop(pconn) { - // nothing - } -} - -func (p *DNSServer) oneloop(pconn net.PacketConn) bool { - buffer := make([]byte, 1<<17) - count, addr, err := pconn.ReadFrom(buffer) - if err != nil { - return !strings.HasSuffix(err.Error(), "use of closed network connection") - } - buffer = buffer[:count] - go p.serveAsync(pconn, addr, buffer) - return true -} - -func (p *DNSServer) emit(pconn net.PacketConn, addr net.Addr, reply ...*dns.Msg) (success int) { - for _, entry := range reply { - replyBytes, err := entry.Pack() - if err != nil { - continue - } - pconn.WriteTo(replyBytes, addr) - success++ // we use this value in tests - } - return -} - -func (p *DNSServer) serveAsync(pconn net.PacketConn, addr net.Addr, buffer []byte) { - query := &dns.Msg{} - if err := query.Unpack(buffer); err != nil { - return - } - if len(query.Question) < 1 { - return // just discard the query - } - name := query.Question[0].Name - switch p.OnQuery(name) { - case DNSActionNXDOMAIN: - p.emit(pconn, addr, p.nxdomain(query)) - case DNSActionLocalHost: - p.emit(pconn, addr, p.localHost(query)) - case DNSActionNoAnswer: - p.emit(pconn, addr, p.empty(query)) - case DNSActionTimeout: - if p.onTimeout != nil { - p.onTimeout <- true - } - case DNSActionCache: - p.emit(pconn, addr, p.cache(name, query)) - case DNSActionLocalHostPlusCache: - p.emit(pconn, addr, p.localHost(query)) - time.Sleep(10 * time.Millisecond) - p.emit(pconn, addr, p.cache(name, query)) - default: - p.emit(pconn, addr, p.refused(query)) - } -} - -func (p *DNSServer) refused(query *dns.Msg) *dns.Msg { - m := new(dns.Msg) - m.SetRcode(query, dns.RcodeRefused) - return m -} - -func (p *DNSServer) nxdomain(query *dns.Msg) *dns.Msg { - m := new(dns.Msg) - m.SetRcode(query, dns.RcodeNameError) - return m -} - -func (p *DNSServer) localHost(query *dns.Msg) *dns.Msg { - return DNSComposeResponse(query, net.IPv6loopback, net.IPv4(127, 0, 0, 1)) -} - -func (p *DNSServer) empty(query *dns.Msg) *dns.Msg { - return DNSComposeResponse(query) -} - func dnsComposeQuery(domain string, qtype uint16) *dns.Msg { question := dns.Question{ Name: dns.Fqdn(domain), @@ -177,6 +25,8 @@ func dnsComposeQuery(domain string, qtype uint16) *dns.Msg { // DNSComposeResponse composes a DNS response using the given IP addresses. If given no // addresses, this function returns a successful, empty response. This function PANICS if // the query argument does not contain EXACTLY one question. +// +// Deprecated: do not use this function in new code func DNSComposeResponse(query *dns.Msg, ips ...net.IP) *dns.Msg { runtimex.PanicIfTrue(len(query.Question) != 1, "expecting a single question") question := query.Question[0] @@ -210,17 +60,3 @@ func DNSComposeResponse(query *dns.Msg, ips ...net.IP) *dns.Msg { } return reply } - -func (p *DNSServer) cache(name string, query *dns.Msg) *dns.Msg { - addrs := p.Cache[name] - var ipAddrs []net.IP - for _, addr := range addrs { - if ip := net.ParseIP(addr); ip != nil { - ipAddrs = append(ipAddrs, ip) - } - } - if len(ipAddrs) <= 0 { - return p.nxdomain(query) - } - return DNSComposeResponse(query, ipAddrs...) -} diff --git a/internal/netxlite/filtering/dns_test.go b/internal/netxlite/filtering/dns_test.go deleted file mode 100644 index 4393ad9d5e..0000000000 --- a/internal/netxlite/filtering/dns_test.go +++ /dev/null @@ -1,341 +0,0 @@ -package filtering - -import ( - "errors" - "net" - "strings" - "testing" - - "github.com/miekg/dns" - "github.com/ooni/probe-cli/v3/internal/mocks" - "github.com/ooni/probe-cli/v3/internal/randx" -) - -func TestDNSServer(t *testing.T) { - newServerWithCache := func(action DNSAction, cache map[string][]string) ( - *DNSServer, DNSListener, <-chan interface{}, error) { - p := &DNSServer{ - Cache: cache, - OnQuery: func(domain string) DNSAction { - return action - }, - onTimeout: make(chan bool), - } - listener, done, err := p.start("127.0.0.1:0") - return p, listener, done, err - } - - newServer := func(action DNSAction) (*DNSServer, DNSListener, <-chan interface{}, error) { - return newServerWithCache(action, nil) - } - - newQuery := func(qtype uint16) *dns.Msg { - question := dns.Question{ - Name: dns.Fqdn("dns.google"), - Qtype: qtype, - Qclass: dns.ClassINET, - } - query := new(dns.Msg) - query.Id = dns.Id() - query.RecursionDesired = true - query.Question = make([]dns.Question, 1) - query.Question[0] = question - return query - } - - t.Run("DNSActionNXDOMAIN", func(t *testing.T) { - _, listener, done, err := newServer(DNSActionNXDOMAIN) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeNameError { - t.Fatal("unexpected rcode") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionRefused", func(t *testing.T) { - _, listener, done, err := newServer(DNSActionRefused) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeRefused { - t.Fatal("unexpected rcode") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionLocalHost", func(t *testing.T) { - _, listener, done, err := newServer(DNSActionLocalHost) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeSuccess { - t.Fatal("unexpected rcode") - } - var found bool - for _, ans := range reply.Answer { - switch v := ans.(type) { - case *dns.A: - found = found || v.A.String() == "127.0.0.1" - } - } - if !found { - t.Fatal("did not find 127.0.0.1") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionEmpty", func(t *testing.T) { - _, listener, done, err := newServer(DNSActionNoAnswer) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeSuccess { - t.Fatal("unexpected rcode") - } - if len(reply.Answer) != 0 { - t.Fatal("expected no answers") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionTimeout", func(t *testing.T) { - srvr, listener, done, err := newServer(DNSActionTimeout) - if err != nil { - t.Fatal(err) - } - c := &dns.Client{} - conn, err := c.Dial(listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - go func() { - <-srvr.onTimeout - conn.Close() // close as soon as the server times out, so this test is fast - }() - reply, _, err := c.ExchangeWithConn(newQuery(dns.TypeA), conn) - if !errors.Is(err, net.ErrClosed) { - t.Fatal("unexpected err", err) - } - if reply != nil { - t.Fatal("expected nil reply here") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionCache without entries", func(t *testing.T) { - _, listener, done, err := newServerWithCache(DNSActionCache, nil) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeNameError { - t.Fatal("unexpected rcode") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionCache with IPv4 entry", func(t *testing.T) { - cache := map[string][]string{ - "dns.google.": {"8.8.8.8"}, - } - _, listener, done, err := newServerWithCache(DNSActionCache, cache) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeSuccess { - t.Fatal("unexpected rcode") - } - var found bool - for _, ans := range reply.Answer { - switch v := ans.(type) { - case *dns.A: - found = found || v.A.String() == "8.8.8.8" - } - } - if !found { - t.Fatal("did not find 8.8.8.8") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionCache with IPv6 entry", func(t *testing.T) { - cache := map[string][]string{ - "dns.google.": {"2001:4860:4860::8888"}, - } - _, listener, done, err := newServerWithCache(DNSActionCache, cache) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeAAAA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeSuccess { - t.Fatal("unexpected rcode") - } - var found bool - for _, ans := range reply.Answer { - switch v := ans.(type) { - case *dns.AAAA: - found = found || v.AAAA.String() == "2001:4860:4860::8888" - } - } - if !found { - t.Fatal("did not find 2001:4860:4860::8888") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("DNSActionLocalHostPlusCache", func(t *testing.T) { - cache := map[string][]string{ - "dns.google.": {"2001:4860:4860::8888"}, - } - _, listener, done, err := newServerWithCache(DNSActionLocalHostPlusCache, cache) - if err != nil { - t.Fatal(err) - } - reply, err := dns.Exchange(newQuery(dns.TypeAAAA), listener.LocalAddr().String()) - if err != nil { - t.Fatal(err) - } - if reply.Rcode != dns.RcodeSuccess { - t.Fatal("unexpected rcode") - } - var found bool - for _, ans := range reply.Answer { - switch v := ans.(type) { - case *dns.AAAA: - found = found || v.AAAA.String() == "::1" - } - } - if !found { - t.Fatal("did not find ::1") - } - listener.Close() - <-done // wait for background goroutine to exit - }) - - t.Run("Start with invalid address", func(t *testing.T) { - p := &DNSServer{} - listener, err := p.Start("127.0.0.1") - if err == nil || !strings.HasSuffix(err.Error(), "missing port in address") { - t.Fatal("unexpected err", err) - } - if listener != nil { - t.Fatal("expected nil listener") - } - }) - - t.Run("oneloop", func(t *testing.T) { - t.Run("ReadFrom failure after which we should continue", func(t *testing.T) { - expected := errors.New("mocked error") - p := &DNSServer{} - conn := &mocks.UDPLikeConn{ - MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { - return 0, nil, expected - }, - } - okay := p.oneloop(conn) - if !okay { - t.Fatal("we should be okay after this error") - } - }) - - t.Run("ReadFrom the connection is closed", func(t *testing.T) { - expected := errors.New("use of closed network connection") - p := &DNSServer{} - conn := &mocks.UDPLikeConn{ - MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { - return 0, nil, expected - }, - } - okay := p.oneloop(conn) - if okay { - t.Fatal("we should not be okay after this error") - } - }) - - t.Run("Unpack fails", func(t *testing.T) { - p := &DNSServer{} - conn := &mocks.UDPLikeConn{ - MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { - if len(p) < 4 { - panic("buffer too small") - } - p[0] = 7 - return 1, &net.UDPAddr{}, nil - }, - } - okay := p.oneloop(conn) - if !okay { - t.Fatal("we should be okay after this error") - } - }) - - t.Run("no questions", func(t *testing.T) { - query := newQuery(dns.TypeA) - query.Question = nil // remove the question - data, err := query.Pack() - if err != nil { - t.Fatal(err) - } - p := &DNSServer{} - conn := &mocks.UDPLikeConn{ - MockReadFrom: func(p []byte) (n int, addr net.Addr, err error) { - if len(p) < len(data) { - panic("buffer too small") - } - copy(p, data) - return len(data), &net.UDPAddr{}, nil - }, - } - okay := p.oneloop(conn) - if !okay { - t.Fatal("we should be okay after this error") - } - }) - }) - - t.Run("pack fails", func(t *testing.T) { - query := newQuery(dns.TypeA) - query.Question[0].Name = randx.Letters(1024) // should be too large - p := &DNSServer{} - count := p.emit(&mocks.UDPLikeConn{}, &mocks.Addr{}, query) - if count != 0 { - t.Fatal("expected to see zero here") - } - }) -} diff --git a/internal/netxlite/filtering/testdata/invalid.json b/internal/netxlite/filtering/testdata/invalid.json deleted file mode 100644 index 98232c64fc..0000000000 --- a/internal/netxlite/filtering/testdata/invalid.json +++ /dev/null @@ -1 +0,0 @@ -{ diff --git a/internal/netxlite/filtering/testdata/valid.json b/internal/netxlite/filtering/testdata/valid.json deleted file mode 100644 index ab9d3fdba0..0000000000 --- a/internal/netxlite/filtering/testdata/valid.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "DNSCache": { - "dns.google": ["8.8.8.8", "8.8.4.4"] - }, - "Domains": { - "x.org": "pass" - } -} diff --git a/internal/netxlite/integration_test.go b/internal/netxlite/integration_test.go index b241373129..6c9000c5bc 100644 --- a/internal/netxlite/integration_test.go +++ b/internal/netxlite/integration_test.go @@ -3,6 +3,7 @@ package netxlite_test import ( "context" "crypto/tls" + "errors" "fmt" "net" "net/http" @@ -18,6 +19,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite/quictesting" "github.com/ooni/probe-cli/v3/internal/randx" "github.com/ooni/probe-cli/v3/internal/runtimex" + "github.com/ooni/probe-cli/v3/internal/testingx" "github.com/quic-go/quic-go" utls "gitlab.com/yawning/utls.git" ) @@ -117,15 +119,12 @@ func TestMeasureWithUDPResolver(t *testing.T) { }) t.Run("for nxdomain", func(t *testing.T) { - proxy := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionNXDOMAIN - }, - } - listener, err := proxy.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, } + dnsRtx := testingx.NewDNSRoundTripperNXDOMAIN() + listener := testingx.MustNewDNSOverUDPListener(udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsRtx) defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) r := netxlite.NewParallelUDPResolver(log.Log, dlr, listener.LocalAddr().String()) @@ -141,15 +140,12 @@ func TestMeasureWithUDPResolver(t *testing.T) { }) t.Run("for refused", func(t *testing.T) { - proxy := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionRefused - }, - } - listener, err := proxy.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, } + dnsRtx := testingx.NewDNSRoundTripperRefused() + listener := testingx.MustNewDNSOverUDPListener(udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsRtx) defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) r := netxlite.NewParallelUDPResolver(log.Log, dlr, listener.LocalAddr().String()) @@ -165,15 +161,12 @@ func TestMeasureWithUDPResolver(t *testing.T) { }) t.Run("for timeout", func(t *testing.T) { - proxy := &filtering.DNSServer{ - OnQuery: func(domain string) filtering.DNSAction { - return filtering.DNSActionTimeout - }, - } - listener, err := proxy.Start("127.0.0.1:0") - if err != nil { - t.Fatal(err) + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, } + dnsRtx := testingx.NewDNSRoundTripperSimulateTimeout(time.Millisecond, errors.New("mocked error")) + listener := testingx.MustNewDNSOverUDPListener(udpAddr, &testingx.DNSOverUDPStdlibListener{}, dnsRtx) defer listener.Close() dlr := netxlite.NewDialerWithoutResolver(log.Log) r := netxlite.NewParallelUDPResolver(log.Log, dlr, listener.LocalAddr().String()) diff --git a/internal/testingx/dnscore.go b/internal/testingx/dnscore.go new file mode 100644 index 0000000000..60262f9d09 --- /dev/null +++ b/internal/testingx/dnscore.go @@ -0,0 +1,85 @@ +package testingx + +import ( + "context" + "os" + "time" + + "github.com/miekg/dns" + "github.com/ooni/netem" +) + +// DNSRoundTripper performs DNS round trips. +type DNSRoundTripper interface { + RoundTrip(ctx context.Context, req []byte) (resp []byte, err error) +} + +// DNSRoundTripperFunc makes a func implement the [DNSRoundTripper] interface. +type DNSRoundTripperFunc func(ctx context.Context, req []byte) (resp []byte, err error) + +var _ DNSRoundTripper = DNSRoundTripperFunc(nil) + +// RoundTrip implements DNSRoundTripper. +func (fx DNSRoundTripperFunc) RoundTrip(ctx context.Context, req []byte) (resp []byte, err error) { + return fx(ctx, req) +} + +// NewDNSRoundTripperWithDNSConfig implements [DNSRroundTripper] using a [*netem.DNSConfig]. +func NewDNSRoundTripperWithDNSConfig(config *netem.DNSConfig) DNSRoundTripper { + return &dnsRoundTripperWithDNSConfig{config} +} + +type dnsRoundTripperWithDNSConfig struct { + config *netem.DNSConfig +} + +// RoundTrip implements DNSRoundTripper. +func (rtx *dnsRoundTripperWithDNSConfig) RoundTrip(ctx context.Context, req []byte) (resp []byte, err error) { + return netem.DNSServerRoundTrip(rtx.config, req) +} + +// NewDNSRoundTripperEmptyRespnse is a [DNSRoundTripper] that always returns an empty response. +func NewDNSRoundTripperEmptyRespnse() DNSRoundTripper { + return DNSRoundTripperFunc(func(ctx context.Context, rawReq []byte) (rawResp []byte, err error) { + req := &dns.Msg{} + if err := req.Unpack(rawReq); err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetRcode(req, dns.RcodeSuccess) + // without any additional RRs + return resp.Pack() + }) +} + +// NewDNSRoundTripperNXDOMAIN is a [DNSRoundTripper] that always returns NXDOMAIN. +func NewDNSRoundTripperNXDOMAIN() DNSRoundTripper { + // An empty DNS config always causes a NXDOMAIN response + return NewDNSRoundTripperWithDNSConfig(netem.NewDNSConfig()) +} + +// NewDNSRoundTripperRefused is a [DNSRoundTripper] that always returns refused. +func NewDNSRoundTripperRefused() DNSRoundTripper { + return DNSRoundTripperFunc(func(ctx context.Context, rawReq []byte) (rawResp []byte, err error) { + req := &dns.Msg{} + if err := req.Unpack(rawReq); err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetRcode(req, dns.RcodeRefused) + return resp.Pack() + }) +} + +// NewDNSRoundTripperSimulateTimeout is a [DNSRoundTripper] that sleeps for the given amount +// of time and then returns to the caller the given error. +func NewDNSRoundTripperSimulateTimeout(timeout time.Duration, err error) DNSRoundTripper { + return DNSRoundTripperFunc(func(ctx context.Context, req []byte) (resp []byte, err error) { + select { + case <-time.After(timeout): + return nil, os.ErrDeadlineExceeded + case <-ctx.Done(): + return nil, ctx.Err() + } + }) +} diff --git a/internal/testingx/dnscore_test.go b/internal/testingx/dnscore_test.go new file mode 100644 index 0000000000..c9e0927447 --- /dev/null +++ b/internal/testingx/dnscore_test.go @@ -0,0 +1,23 @@ +package testingx + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestNewDNSRoundTripSimulateTimeout(t *testing.T) { + t.Run("when the context has already been cancelled", func(t *testing.T) { + rtx := NewDNSRoundTripperSimulateTimeout(time.Second, errors.New("mocked error")) + ctx, cancel := context.WithCancel(context.Background()) + cancel() // immediately cancel + resp, err := rtx.RoundTrip(ctx, make([]byte, 128)) + if !errors.Is(err, context.Canceled) { + t.Fatal("unexpected err", err) + } + if len(resp) != 0 { + t.Fatal("expected zero-byte resp") + } + }) +} diff --git a/internal/testingx/dnsoverhttps.go b/internal/testingx/dnsoverhttps.go index 15cf7dc9f5..7f334343d2 100644 --- a/internal/testingx/dnsoverhttps.go +++ b/internal/testingx/dnsoverhttps.go @@ -4,15 +4,13 @@ import ( "io" "net/http" - "github.com/ooni/netem" "github.com/ooni/probe-cli/v3/internal/runtimex" ) // DNSOverHTTPSHandler is an [http.Handler] implementing DNS-over-HTTPS. type DNSOverHTTPSHandler struct { - // Config is the MANDATORY config telling this DNS server which specific mappings - // between domain names and IP addresses it knows. - Config *netem.DNSConfig + // RoundTripper is the MANDATORY round tripper to use. + RoundTripper DNSRoundTripper } var _ http.Handler = &DNSOverHTTPSHandler{} @@ -21,13 +19,13 @@ var _ http.Handler = &DNSOverHTTPSHandler{} func (p *DNSOverHTTPSHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { defer p.handlePanic(w) rawQuery := runtimex.Try1(io.ReadAll(r.Body)) - rawResponse := runtimex.Try1(netem.DNSServerRoundTrip(p.Config, rawQuery)) + rawResponse := runtimex.Try1(p.RoundTripper.RoundTrip(r.Context(), rawQuery)) w.Header().Add("content-type", "application/dns-message") w.Write(rawResponse) } func (p *DNSOverHTTPSHandler) handlePanic(w http.ResponseWriter) { if r := recover(); r != nil { - w.WriteHeader(500) + w.WriteHeader(http.StatusInternalServerError) } } diff --git a/internal/testingx/dnsoverhttps_test.go b/internal/testingx/dnsoverhttps_test.go index 59bfaf627f..f97c2b07de 100644 --- a/internal/testingx/dnsoverhttps_test.go +++ b/internal/testingx/dnsoverhttps_test.go @@ -2,10 +2,12 @@ package testingx import ( "bytes" + "errors" "io" "net/http" "net/http/httptest" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/miekg/dns" @@ -43,59 +45,113 @@ func TestDNSOverHTTPSHandler(t *testing.T) { 0x00, 0x01, // QCLASS (IN) } - config := netem.NewDNSConfig() - config.AddRecord("example.com", "web01.example.com", "93.184.216.34") - handler := &DNSOverHTTPSHandler{ - Config: config, - } - server := httptest.NewServer(handler) - defer server.Close() - type testconfig struct { name string + newHandler func() http.Handler query []byte expectStatus int expectResponse []byte } testcases := []testconfig{{ - name: "when querying for an existing domain", + name: "when querying for an existing domain", + newHandler: func() http.Handler { + config := netem.NewDNSConfig() + config.AddRecord("example.com", "web01.example.com", "93.184.216.34") + return &DNSOverHTTPSHandler{ + RoundTripper: NewDNSRoundTripperWithDNSConfig(config), + } + }, query: exampleComQuery, expectStatus: 200, expectResponse: []byte{ - 0x00, 0x01, 0x80, 0x00, 0x00, 0x01, 0x00, 0x02, - 0x00, 0x00, 0x00, 0x00, 0x07, 0x65, 0x78, 0x61, - 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, - 0x00, 0x00, 0x01, 0x00, 0x01, 0x07, 0x65, 0x78, - 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, - 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, - 0x0e, 0x10, 0x00, 0x04, 0x5d, 0xb8, 0xd8, 0x22, - 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, - 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x05, 0x00, - 0x01, 0x00, 0x00, 0x0e, 0x10, 0x00, 0x13, 0x05, - 0x77, 0x65, 0x62, 0x30, 0x31, 0x07, 0x65, 0x78, - 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, - 0x6d, 0x00, + 0x00, 0x01, // Transaction ID + 0x80, 0x00, // Flags (response, recursion desired) + 0x00, 0x01, // Num questions + 0x00, 0x02, // Num asnwers RRs + 0x00, 0x00, // Num Authority RRs + 0x00, 0x00, // Num Additional RRs + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x04, // data length: 4 bytes + 0x5d, 0xb8, 0xd8, 0x22, // IPv4 address (93.184.216.34) + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x05, // type = CNAME + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x13, // data length = 19 bytes + 0x05, 0x77, 0x65, 0x62, 0x30, 0x31, // QNAME: 5(web01) + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator }, }, { - name: "when querying for a nonexisting domain", + name: "when querying for a nonexisting domain", + newHandler: func() http.Handler { + config := netem.NewDNSConfig() + config.AddRecord("example.com", "web01.example.com", "93.184.216.34") + return &DNSOverHTTPSHandler{ + RoundTripper: NewDNSRoundTripperWithDNSConfig(config), + } + }, query: exampleOrgQuery, expectStatus: 200, expectResponse: []byte{ - 0x00, 0x01, 0x80, 0x03, 0x00, 0x01, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x07, 0x65, 0x78, 0x61, - 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x6f, 0x72, 0x67, - 0x00, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x01, // Transaction ID + 0x80, 0x03, // Flags (Response, NXDOMAIN) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num answers RRs + 0x00, 0x00, // Num authority RRs + 0x00, 0x00, // Num additional RRs + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x6f, 0x72, 0x67, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN }, }, { - name: "with invalid query", + name: "with invalid query", + newHandler: func() http.Handler { + config := netem.NewDNSConfig() + config.AddRecord("example.com", "web01.example.com", "93.184.216.34") + return &DNSOverHTTPSHandler{ + RoundTripper: NewDNSRoundTripperWithDNSConfig(config), + } + }, query: []byte{0x22}, expectStatus: 500, expectResponse: []byte{}, + }, { + name: "with internal round trip error", + newHandler: func() http.Handler { + return &DNSOverHTTPSHandler{ + RoundTripper: NewDNSRoundTripperSimulateTimeout(time.Millisecond, errors.New("antani")), + } + }, + query: exampleComQuery, + expectStatus: 500, + expectResponse: []byte{}, }} for _, tc := range testcases { t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(tc.newHandler()) + defer server.Close() + req, err := http.NewRequest("POST", server.URL, bytes.NewReader(tc.query)) if err != nil { t.Fatal(err) diff --git a/internal/testingx/dnsoverudp.go b/internal/testingx/dnsoverudp.go new file mode 100644 index 0000000000..b03ab64995 --- /dev/null +++ b/internal/testingx/dnsoverudp.go @@ -0,0 +1,106 @@ +package testingx + +import ( + "context" + "errors" + "net" + "sync" + + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// DNSOverUDPUnderlyingListener is the underlying listener used by [DNSOverUDPListener]. +type DNSOverUDPUnderlyingListener interface { + ListenUDP(network string, addr *net.UDPAddr) (net.PacketConn, error) +} + +// DNSOverUDPStdlibListener implements [DNSOverUDPUnderlyingListener] using the standard library. +type DNSOverUDPStdlibListener struct{} + +var _ DNSOverUDPUnderlyingListener = &DNSOverUDPStdlibListener{} + +// ListenUDP implements DNSOverUDPUnderlyingListener. +func (*DNSOverUDPStdlibListener) ListenUDP(network string, addr *net.UDPAddr) (net.PacketConn, error) { + return net.ListenUDP(network, addr) +} + +// DNSOverUDPListener is a DNS-over-UDP listener. The zero value of this +// struct is invalid, please use [NewDNSOverUDPListener]. +type DNSOverUDPListener struct { + cancel context.CancelFunc + closeOnce sync.Once + pconn net.PacketConn + rtx DNSRoundTripper + wg sync.WaitGroup +} + +// MustNewDNSOverUDPListener creates a new [DNSOverUDPListener] using the given +// [DNSOverUDPUnderlyingListener], [DNSRoundTripper], and [*net.UDPAddr]. +func MustNewDNSOverUDPListener(addr *net.UDPAddr, dul DNSOverUDPUnderlyingListener, rtx DNSRoundTripper) *DNSOverUDPListener { + pconn := runtimex.Try1(dul.ListenUDP("udp", addr)) + ctx, cancel := context.WithCancel(context.Background()) + dl := &DNSOverUDPListener{ + cancel: cancel, + closeOnce: sync.Once{}, + pconn: pconn, + rtx: rtx, + wg: sync.WaitGroup{}, + } + dl.wg.Add(1) + go dl.mainloop(ctx) + return dl +} + +// LocalAddr returns the connection address. The return value is nil after you called Close. +func (dl *DNSOverUDPListener) LocalAddr() net.Addr { + return dl.pconn.LocalAddr() +} + +// Close implements io.Closer. +func (dl *DNSOverUDPListener) Close() (err error) { + dl.closeOnce.Do(func() { + // close the connection to interrupt ReadFrom or WriteTo + err = dl.pconn.Close() + + // cancel the context to interrupt the round tripper + dl.cancel() + + // wait for the background goroutine to join + dl.wg.Wait() + }) + return err +} + +func (dl *DNSOverUDPListener) mainloop(ctx context.Context) { + // synchronize with Close + defer dl.wg.Done() + + for { + // read from the socket + buffer := make([]byte, 1<<17) + count, addr, err := dl.pconn.ReadFrom(buffer) + + // handle errors including the case in which we're closed + if errors.Is(err, net.ErrClosed) { + return + } + if err != nil { + continue + } + + // prepare the raw request for the round tripper + rawReq := buffer[:count] + + // perform the round trip + rawResp, err := dl.rtx.RoundTrip(ctx, rawReq) + + // on error, just ignore the message + if err != nil { + continue + } + + // emit the message and ignore any error; we'll notice ErrClosed + // in the next ReadFrom call and stop the loop + _, _ = dl.pconn.WriteTo(rawResp, addr) + } +} diff --git a/internal/testingx/dnsoverudp_test.go b/internal/testingx/dnsoverudp_test.go new file mode 100644 index 0000000000..0717ee21f7 --- /dev/null +++ b/internal/testingx/dnsoverudp_test.go @@ -0,0 +1,297 @@ +package testingx + +import ( + "context" + "errors" + "net" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/miekg/dns" + "github.com/ooni/netem" + "github.com/ooni/probe-cli/v3/internal/mocks" +) + +func TestDNSOverUDPHandler(t *testing.T) { + exampleComQuery := []byte{ + 0x00, 0x01, // Transaction ID + 0x00, 0x00, // Flags + 0x00, 0x01, // Questions + 0x00, 0x00, // Answer RRs + 0x00, 0x00, // Authority RRs + 0x00, 0x00, // Additional RRs + // QNAME + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, // Null-terminator of QNAME + 0x00, 0x01, // QTYPE (A record) + 0x00, 0x01, // QCLASS (IN) + } + + exampleOrgQuery := []byte{ + 0x00, 0x01, // Transaction ID + 0x00, 0x00, // Flags + 0x00, 0x01, // Questions + 0x00, 0x00, // Answer RRs + 0x00, 0x00, // Authority RRs + 0x00, 0x00, // Additional RRs + // QNAME + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'o', 'r', 'g', + 0x00, // Null-terminator of QNAME + 0x00, 0x01, // QTYPE (A record) + 0x00, 0x01, // QCLASS (IN) + } + + type testconfig struct { + name string + newRoundTripper func() DNSRoundTripper + query []byte + expectErr error + expectResponse []byte + } + + testcases := []testconfig{{ + name: "when querying for an existing domain", + newRoundTripper: func() DNSRoundTripper { + config := netem.NewDNSConfig() + config.AddRecord("example.com", "web01.example.com", "93.184.216.34") + return NewDNSRoundTripperWithDNSConfig(config) + }, + query: exampleComQuery, + expectErr: nil, + expectResponse: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x00, // Flags (response) + 0x00, 0x01, // Num questions + 0x00, 0x02, // Num asnwers RRs + 0x00, 0x00, // Num Authority RRs + 0x00, 0x00, // Num Additional RRs + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x04, // data length: 4 bytes + 0x5d, 0xb8, 0xd8, 0x22, // IPv4 address (93.184.216.34) + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x05, // type = CNAME + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x13, // data length = 19 bytes + 0x05, 0x77, 0x65, 0x62, 0x30, 0x31, // QNAME: 5(web01) + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + }, + }, { + name: "when querying for a nonexisting domain", + newRoundTripper: func() DNSRoundTripper { + config := netem.NewDNSConfig() + config.AddRecord("example.com", "web01.example.com", "93.184.216.34") + return NewDNSRoundTripperWithDNSConfig(config) + }, + query: exampleOrgQuery, + expectErr: nil, + expectResponse: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x03, // Flags (Response, NXDOMAIN) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num answers RRs + 0x00, 0x00, // Num authority RRs + 0x00, 0x00, // Num additional RRs + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x6f, 0x72, 0x67, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + }, + }, { + name: "with invalid query", + newRoundTripper: func() DNSRoundTripper { + config := netem.NewDNSConfig() + config.AddRecord("example.com", "web01.example.com", "93.184.216.34") + return NewDNSRoundTripperWithDNSConfig(config) + }, + query: []byte{0x22}, + expectErr: os.ErrDeadlineExceeded, + expectResponse: []byte{}, + }, { + name: "with round trip timeout", + newRoundTripper: func() DNSRoundTripper { + return NewDNSRoundTripperSimulateTimeout(time.Millisecond, errors.New("antani")) + }, + query: exampleComQuery, + expectErr: os.ErrDeadlineExceeded, + expectResponse: []byte{}, + }, { + name: "with DNSRoundTripperEmptyResponse and valid query", + newRoundTripper: func() DNSRoundTripper { + return NewDNSRoundTripperEmptyRespnse() + }, + query: exampleComQuery, + expectErr: nil, + expectResponse: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x00, // Flags (response) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num asnwers RRs + 0x00, 0x00, // Num Authority RRs + 0x00, 0x00, // Num Additional RRs + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + }, + }, { + name: "with DNSRoundTripperEmptyResponse and and invalid query", + newRoundTripper: func() DNSRoundTripper { + return NewDNSRoundTripperEmptyRespnse() + }, + query: []byte{0x22}, + expectErr: os.ErrDeadlineExceeded, + expectResponse: []byte{}, + }, { + name: "with DNSRoundTripperRefused and valid query", + newRoundTripper: func() DNSRoundTripper { + return NewDNSRoundTripperRefused() + }, + query: exampleComQuery, + expectErr: nil, + expectResponse: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x05, // Flags (response, refused) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num asnwers RRs + 0x00, 0x00, // Num Authority RRs + 0x00, 0x00, // Num Additional RRs + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + }, + }, { + name: "with DNSRoundTripperEmptyResponse and and invalid query", + newRoundTripper: func() DNSRoundTripper { + return NewDNSRoundTripperRefused() + }, + query: []byte{0x22}, + expectErr: os.ErrDeadlineExceeded, + expectResponse: []byte{}, + }, { + name: "with DNSRoundTripperNXDOMAIN", + newRoundTripper: func() DNSRoundTripper { + return NewDNSRoundTripperNXDOMAIN() + }, + query: exampleOrgQuery, + expectErr: nil, + expectResponse: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x03, // Flags (Response, NXDOMAIN) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num answers RRs + 0x00, 0x00, // Num authority RRs + 0x00, 0x00, // Num additional RRs + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x6f, 0x72, 0x67, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + }, + }} + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + } + listener := MustNewDNSOverUDPListener(udpAddr, &DNSOverUDPStdlibListener{}, tc.newRoundTripper()) + defer listener.Close() + + pconn, err := net.Dial("udp", listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + pconn.SetDeadline(time.Now().Add(250 * time.Millisecond)) + _, _ = pconn.Write(tc.query) + + buffer := make([]byte, 1<<14) + count, err := pconn.Read(buffer) + + switch { + case tc.expectErr == nil && err != nil: + t.Fatal("expected no error but got", err) + case tc.expectErr != nil && err == nil: + t.Fatal("expected", tc.expectErr, "but got", err) + case tc.expectErr != nil && err != nil: + if !errors.Is(err, tc.expectErr) { + t.Fatal("expected", tc.expectErr, "but got", err) + } + return + default: + // fallthrough + } + + if err != nil { + t.Fatal(err) + } + + rawResponse := buffer[:count] + msg := &dns.Msg{} + if err := msg.Unpack(rawResponse); err != nil { + t.Fatal(err) + } + t.Logf("\n%s", msg) + t.Logf("%#v", rawResponse) + + if diff := cmp.Diff(tc.expectResponse, rawResponse); diff != "" { + t.Fatal(diff) + } + }) + } + + t.Run("when there is an error reading in the main loop", func(t *testing.T) { + called := &atomic.Bool{} + rtx := &DNSOverUDPListener{ + cancel: func() { + // nothing to do here + }, + closeOnce: sync.Once{}, + pconn: &mocks.UDPLikeConn{ + MockReadFrom: func(p []byte) (int, net.Addr, error) { + if called.Load() { + return 0, nil, net.ErrClosed + } + called.Store(true) + return 0, nil, errors.New("mocked error") + }, + }, + rtx: nil, + wg: sync.WaitGroup{}, + } + + rtx.wg.Add(1) + go rtx.mainloop(context.Background()) + rtx.wg.Wait() + }) +} diff --git a/internal/testingx/dnssimulategfw.go b/internal/testingx/dnssimulategfw.go new file mode 100644 index 0000000000..ac4af56ebd --- /dev/null +++ b/internal/testingx/dnssimulategfw.go @@ -0,0 +1,115 @@ +package testingx + +import ( + "errors" + "net" + "sync" + + "github.com/ooni/netem" + "github.com/ooni/probe-cli/v3/internal/runtimex" +) + +// DNSNumBogusResponses is a type indicating the number of bogus responses +// the [DNSSimulateGWFListener] should emit for each round trip. +type DNSNumBogusResponses int + +// DNSSimulateGWFListener is a DNS-over-UDP listener that simulates the GFW behavior by +// responding with N+1 answers, where the first N answers are invalid for the domain +// and the last answer is correct for the domain. The zero value of this struct is +// invalid, please use [NewDNSSimulateGWFListener]. +type DNSSimulateGWFListener struct { + bogusConfig *netem.DNSConfig + closeOnce sync.Once + goodConfig *netem.DNSConfig + numBogus DNSNumBogusResponses + pconn net.PacketConn + wg sync.WaitGroup +} + +// MustNewDNSSimulateGWFListener creates a new [DNSSimulateGWFListener] using the given +// [DNSOverUDPUnderlyingListener], [*net.UDPAddr], and [*netem.DNSConfig]. The bogusConfig +// is used to prepare the bogus responses, and the good config is used to prepare the +// final response containing valid IP addresses for the domain. If numBogusResponses is +// less or equal than 1, we will force its value to be 1. +func MustNewDNSSimulateGWFListener( + addr *net.UDPAddr, + dul DNSOverUDPUnderlyingListener, + bogusConfig *netem.DNSConfig, + goodConfig *netem.DNSConfig, + numBogusResponses DNSNumBogusResponses, +) *DNSSimulateGWFListener { + pconn := runtimex.Try1(dul.ListenUDP("udp", addr)) + if numBogusResponses < 1 { + numBogusResponses = 1 // as documented + } + dl := &DNSSimulateGWFListener{ + bogusConfig: bogusConfig, + closeOnce: sync.Once{}, + goodConfig: goodConfig, + numBogus: numBogusResponses, + pconn: pconn, + wg: sync.WaitGroup{}, + } + dl.wg.Add(1) + go dl.mainloop() + return dl +} + +// LocalAddr returns the connection address. The return value is nil after you called Close. +func (dl *DNSSimulateGWFListener) LocalAddr() net.Addr { + return dl.pconn.LocalAddr() +} + +// Close implements io.Closer. +func (dl *DNSSimulateGWFListener) Close() (err error) { + dl.closeOnce.Do(func() { + // close the connection to interrupt ReadFrom or WriteTo + err = dl.pconn.Close() + + // wait for the background goroutine to join + dl.wg.Wait() + }) + return err +} + +func (dl *DNSSimulateGWFListener) mainloop() { + // synchronize with Close + defer dl.wg.Done() + + for { + // read from the socket + buffer := make([]byte, 1<<17) + count, addr, err := dl.pconn.ReadFrom(buffer) + + // handle errors including the case in which we're closed + if errors.Is(err, net.ErrClosed) { + return + } + if err != nil { + continue + } + + // prepare the raw request for the round tripper + rawReq := buffer[:count] + + // emit N >= 1 bogus responses followed by a valid response + for idx := DNSNumBogusResponses(0); idx < dl.numBogus; idx++ { + dl.writeResponse(addr, dl.bogusConfig, rawReq) + } + dl.writeResponse(addr, dl.goodConfig, rawReq) + } +} + +func (dl *DNSSimulateGWFListener) writeResponse(addr net.Addr, config *netem.DNSConfig, rawReq []byte) { + // perform the round trip + rawResp, err := netem.DNSServerRoundTrip(config, rawReq) + + // on error, just ignore the message + if err != nil { + return + } + + // emit the message and ignore any error; we'll notice ErrClosed + // in the next ReadFrom call and stop the loop + _, _ = dl.pconn.WriteTo(rawResp, addr) +} diff --git a/internal/testingx/dnssimulategfw_test.go b/internal/testingx/dnssimulategfw_test.go new file mode 100644 index 0000000000..6e083d3eec --- /dev/null +++ b/internal/testingx/dnssimulategfw_test.go @@ -0,0 +1,259 @@ +package testingx + +import ( + "errors" + "net" + "os" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/miekg/dns" + "github.com/ooni/netem" + "github.com/ooni/probe-cli/v3/internal/mocks" +) + +func TestDNSSimulateGFW(t *testing.T) { + exampleComQuery := []byte{ + 0x00, 0x01, // Transaction ID + 0x00, 0x00, // Flags + 0x00, 0x01, // Questions + 0x00, 0x00, // Answer RRs + 0x00, 0x00, // Authority RRs + 0x00, 0x00, // Additional RRs + // QNAME + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'c', 'o', 'm', + 0x00, // Null-terminator of QNAME + 0x00, 0x01, // QTYPE (A record) + 0x00, 0x01, // QCLASS (IN) + } + + exampleOrgQuery := []byte{ + 0x00, 0x01, // Transaction ID + 0x00, 0x00, // Flags + 0x00, 0x01, // Questions + 0x00, 0x00, // Answer RRs + 0x00, 0x00, // Authority RRs + 0x00, 0x00, // Additional RRs + // QNAME + 0x07, 'e', 'x', 'a', 'm', 'p', 'l', 'e', + 0x03, 'o', 'r', 'g', + 0x00, // Null-terminator of QNAME + 0x00, 0x01, // QTYPE (A record) + 0x00, 0x01, // QCLASS (IN) + } + + type testconfig struct { + name string + query []byte + expectErr error + expectResponseBogus []byte + expectResponseGood []byte + } + + testcases := []testconfig{{ + name: "when the query is valid", + query: exampleComQuery, + expectErr: nil, + expectResponseBogus: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x00, // Flags (response) + 0x00, 0x01, // Num questions + 0x00, 0x01, // Num asnwers RRs + 0x00, 0x00, // Num Authority RRs + 0x00, 0x00, // Num Additional RRs + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x04, // data length: 4 bytes + 0x0a, 0x0a, 0x22, 0x23, // IPv4 address (10.10.34.35) + }, + expectResponseGood: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x00, // Flags (response) + 0x00, 0x01, // Num questions + 0x00, 0x02, // Num asnwers RRs + 0x00, 0x00, // Num Authority RRs + 0x00, 0x00, // Num Additional RRs + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x04, // data length: 4 bytes + 0x5d, 0xb8, 0xd8, 0x22, // IPv4 address (93.184.216.34) + + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x05, // type = CNAME + 0x00, 0x01, // class = IN + 0x00, 0x00, 0x0e, 0x10, // TTL = 3600 seconds + 0x00, 0x13, // data length = 19 bytes + 0x05, 0x77, 0x65, 0x62, 0x30, 0x31, // QNAME: 5(web01) + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x63, 0x6f, 0x6d, // QNAME: 3(com) + 0x00, // QNAME: null terminator + }, + }, { + name: "when querying for a nonexisting domain", + query: exampleOrgQuery, + expectErr: nil, + expectResponseBogus: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x03, // Flags (Response, NXDOMAIN) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num answers RRs + 0x00, 0x00, // Num authority RRs + 0x00, 0x00, // Num additional RRs + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x6f, 0x72, 0x67, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + }, + expectResponseGood: []byte{ + 0x00, 0x01, // Transaction ID + 0x80, 0x03, // Flags (Response, NXDOMAIN) + 0x00, 0x01, // Num questions + 0x00, 0x00, // Num answers RRs + 0x00, 0x00, // Num authority RRs + 0x00, 0x00, // Num additional RRs + 0x07, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, // QNAME: 7(example) + 0x03, 0x6f, 0x72, 0x67, // QNAME: 3(com) + 0x00, // QNAME: null terminator + 0x00, 0x01, // type = A + 0x00, 0x01, // class = IN + }, + }, { + name: "with invalid query", + query: []byte{0x22}, + expectErr: os.ErrDeadlineExceeded, + expectResponseBogus: []byte{}, + expectResponseGood: []byte{}, + }} + + for _, tc := range testcases { + t.Run(tc.name, func(t *testing.T) { + bogusConfig := netem.NewDNSConfig() + bogusConfig.AddRecord("example.com", "", "10.10.34.35") + goodConfig := netem.NewDNSConfig() + goodConfig.AddRecord("example.com", "web01.example.com", "93.184.216.34") + + udpAddr := &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + } + listener := MustNewDNSSimulateGWFListener( + udpAddr, &DNSOverUDPStdlibListener{}, bogusConfig, + goodConfig, DNSNumBogusResponses(2)) + defer listener.Close() + + pconn, err := net.Dial("udp", listener.LocalAddr().String()) + if err != nil { + t.Fatal(err) + } + pconn.SetDeadline(time.Now().Add(250 * time.Millisecond)) + _, _ = pconn.Write(tc.query) + + for idx := 0; idx < 3; idx++ { + buffer := make([]byte, 1<<14) + count, err := pconn.Read(buffer) + + switch { + case tc.expectErr == nil && err != nil: + t.Fatal("expected no error but got", err) + case tc.expectErr != nil && err == nil: + t.Fatal("expected", tc.expectErr, "but got", err) + case tc.expectErr != nil && err != nil: + if !errors.Is(err, tc.expectErr) { + t.Fatal("expected", tc.expectErr, "but got", err) + } + return + default: + // fallthrough + } + + if err != nil { + t.Fatal(err) + } + + rawResponse := buffer[:count] + msg := &dns.Msg{} + if err := msg.Unpack(rawResponse); err != nil { + t.Fatal(err) + } + t.Logf("\n%s", msg) + t.Logf("%#v", rawResponse) + + expectedResp := tc.expectResponseBogus + if idx == 2 { + expectedResp = tc.expectResponseGood + } + + if diff := cmp.Diff(expectedResp, rawResponse); diff != "" { + t.Fatal(diff) + } + } + }) + } + + t.Run("when there is an error reading in the main loop", func(t *testing.T) { + called := &atomic.Bool{} + rtx := &DNSSimulateGWFListener{ + bogusConfig: netem.NewDNSConfig(), + closeOnce: sync.Once{}, + goodConfig: netem.NewDNSConfig(), + pconn: &mocks.UDPLikeConn{MockReadFrom: func(p []byte) (int, net.Addr, error) { + if called.Load() { + return 0, nil, net.ErrClosed + } + called.Store(true) + return 0, nil, errors.New("mocked error") + }}, + wg: sync.WaitGroup{}, + } + + rtx.wg.Add(1) + go rtx.mainloop() + rtx.wg.Wait() + }) + + t.Run("the constructor forces the NumBogusResponses to be 1 when < 1", func(t *testing.T) { + rtx := MustNewDNSSimulateGWFListener( + &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 0, + }, + &DNSOverUDPStdlibListener{}, + netem.NewDNSConfig(), + netem.NewDNSConfig(), + DNSNumBogusResponses(0), + ) + defer rtx.Close() + if rtx.numBogus != 1 { + t.Fatal("expected to see rtx.numBogus == 1, found", rtx.numBogus) + } + }) +}