Skip to content

Commit

Permalink
resolve review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ameshkov committed Apr 27, 2024
1 parent 075a44c commit 5ec9207
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion upstream/doq.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ func (p *dnsOverQUIC) Exchange(m *dns.Msg) (resp *dns.Msg, err error) {
// attempt.
conn, _, err = p.getConnection()
if err != nil {
return nil, err
return nil, fmt.Errorf("getting new conn: %w", err)
}

// Retry sending the request through the new connection.
Expand Down
39 changes: 16 additions & 23 deletions upstream/doq_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ func TestUpstreamDoQ(t *testing.T) {
tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

srv := startDoQServer(t, tlsConf, 0)
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)

address := fmt.Sprintf("quic://%s", srv.addr)
var lastState tls.ConnectionState
Expand All @@ -38,8 +37,7 @@ func TestUpstreamDoQ(t *testing.T) {

return nil
},
RootCAs: rootCAs,
InsecureSkipVerify: true,
RootCAs: rootCAs,
}
u, err := AddressToUpstream(address, opts)
require.NoError(t, err)
Expand Down Expand Up @@ -78,7 +76,7 @@ func TestUpstreamDoQ(t *testing.T) {
checkRaceCondition(u)
}

func TestUpstreamDoQ_serverCloseConn(t *testing.T) {
func TestUpstream_Exchange_quicServerCloseConn(t *testing.T) {
// Use the same tlsConf for all servers to preserve the data necessary for
// 0-RTT connections.
tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")
Expand All @@ -88,11 +86,7 @@ func TestUpstreamDoQ_serverCloseConn(t *testing.T) {

// Create a DNS-over-QUIC upstream.
address := fmt.Sprintf("quic://%s", srv.addr)
u, err := AddressToUpstream(address, &Options{
InsecureSkipVerify: true,
Timeout: 250 * time.Millisecond,
RootCAs: rootCAs,
})
u, err := AddressToUpstream(address, &Options{RootCAs: rootCAs})

require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
Expand All @@ -103,23 +97,28 @@ func TestUpstreamDoQ_serverCloseConn(t *testing.T) {
// Close all active connections.
srv.closeConns()

// Now run several queries in parallel to trigger the error from
// Now run several queries in parallel to check that the error from the
// following issue is not happening:
// https://github.com/AdguardTeam/dnsproxy/issues/389.
//
// Run 10 queries in parallel as the initial testing showed that this is
// enough to trigger the race issue.
const parallelQueries = 10

wg := sync.WaitGroup{}
wg.Add(10)
wg.Add(parallelQueries)

for i := 0; i < 10; i++ {
go func() {
t.Helper()
pt := testutil.PanicT{}

go func(t assert.TestingT) {
defer wg.Done()

req := createTestMessage()
_, uErr := u.Exchange(req)

assert.NoError(t, uErr)
}()
}(pt)
}

wg.Wait()
Expand All @@ -144,11 +143,7 @@ func TestUpstreamDoQ_serverRestart(t *testing.T) {
}).String()

var err error
u, err = AddressToUpstream(upsStr, &Options{
InsecureSkipVerify: true,
Timeout: 250 * time.Millisecond,
RootCAs: rootCAs,
})
u, err = AddressToUpstream(upsStr, &Options{RootCAs: rootCAs})
require.NoError(t, err)

checkUpstream(t, u, upsStr)
Expand Down Expand Up @@ -177,14 +172,12 @@ func TestUpstreamDoQ_0RTT(t *testing.T) {
tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

srv := startDoQServer(t, tlsConf, 0)
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)

tracer := &quicTracer{}
address := fmt.Sprintf("quic://%s", srv.addr)
u, err := AddressToUpstream(address, &Options{
InsecureSkipVerify: true,
QUICTracer: tracer.TracerForConnection,
RootCAs: rootCAs,
QUICTracer: tracer.TracerForConnection,
RootCAs: rootCAs,
})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, u.Close)
Expand Down
8 changes: 7 additions & 1 deletion upstream/upstream_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,13 @@ func createServerTLSConfig(
BasicConstraintsValid: true,
IsCA: true,
}
template.DNSNames = append(template.DNSNames, tlsServerName)

ipAddress := net.ParseIP(tlsServerName)
if ipAddress != nil {
template.IPAddresses = append(template.IPAddresses, ipAddress)
} else {
template.DNSNames = append(template.DNSNames, tlsServerName)
}

derBytes, err := x509.CreateCertificate(
rand.Reader,
Expand Down

0 comments on commit 5ec9207

Please sign in to comment.