diff --git a/internal/dnsforward/beforerequest.go b/internal/dnsforward/beforerequest.go index 75c64cecfc7..a1232ce7ed4 100644 --- a/internal/dnsforward/beforerequest.go +++ b/internal/dnsforward/beforerequest.go @@ -17,8 +17,6 @@ var _ proxy.BeforeRequestHandler = (*Server)(nil) // HandleBefore is the handler that is called before any other processing, // including logs. It performs access checks and puts the client ID, if there // is one, into the server's cache. -// -// TODO(e.burkov): Write tests. func (s *Server) HandleBefore( _ *proxy.Proxy, pctx *proxy.DNSContext, diff --git a/internal/dnsforward/dnsforward_test.go b/internal/dnsforward/dnsforward_test.go index 12767cc1876..05d7abdf21d 100644 --- a/internal/dnsforward/dnsforward_test.go +++ b/internal/dnsforward/dnsforward_test.go @@ -1652,3 +1652,112 @@ func TestServer_Exchange(t *testing.T) { assert.Empty(t, host) }) } + +func TestServer_HandleBefore(t *testing.T) { + const ( + blockedHost = "blockedhost.org" + clientID = "client-1" + testHost = "example.org." + ) + + testCases := []struct { + want assert.ValueAssertionFunc + clientSrvName string + name string + host string + allowedClients []string + disallowedClients []string + blockedHosts []string + }{{ + want: assert.NotEmpty, + clientSrvName: tlsServerName, + name: "allow_all", + host: testHost, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{}, + }, { + want: assert.NotEmpty, + clientSrvName: clientID + "." + tlsServerName, + name: "allowed_client_allowed", + host: testHost, + allowedClients: []string{clientID}, + disallowedClients: []string{}, + blockedHosts: []string{}, + }, { + want: assert.Empty, + clientSrvName: "client-2." + tlsServerName, + name: "allowed_client_rejected", + host: testHost, + allowedClients: []string{clientID}, + disallowedClients: []string{}, + blockedHosts: []string{}, + }, { + want: assert.NotEmpty, + clientSrvName: tlsServerName, + name: "disallowed_client_allowed", + host: testHost, + allowedClients: []string{}, + disallowedClients: []string{clientID}, + blockedHosts: []string{}, + }, { + want: assert.Empty, + clientSrvName: clientID + "." + tlsServerName, + name: "disallowed_client_rejected", + host: testHost, + allowedClients: []string{}, + disallowedClients: []string{clientID}, + blockedHosts: []string{}, + }, { + want: assert.NotEmpty, + clientSrvName: tlsServerName, + name: "blocked_hosts_allowed", + host: testHost, + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{blockedHost}, + }, { + want: assert.Empty, + clientSrvName: tlsServerName, + name: "blocked_hosts_rejected", + host: dns.Fqdn(blockedHost), + allowedClients: []string{}, + disallowedClients: []string{}, + blockedHosts: []string{blockedHost}, + }} + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + s, _ := createTestTLS(t, TLSConfig{ + TLSListenAddrs: []*net.TCPAddr{{}}, + ServerName: tlsServerName, + }) + + s.conf.AllowedClients = tc.allowedClients + s.conf.DisallowedClients = tc.disallowedClients + s.conf.BlockedHosts = tc.blockedHosts + + err := s.Prepare(&s.conf) + require.NoErrorf(t, err, "failed to prepare server: %s", err) + + startDeferStop(t, s) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + ServerName: tc.clientSrvName, + } + + client := &dns.Client{ + Net: "tcp-tls", + TLSConfig: tlsConfig, + } + + req := createTestMessage(tc.host) + addr := s.dnsProxy.Addr(proxy.ProtoTLS).String() + + reply, _, err := client.Exchange(req, addr) + require.NoErrorf(t, err, "couldn't talk to server %s: %s", addr, err) + tc.want(t, reply.Answer) + }) + } +}