Skip to content

Commit

Permalink
Merge branch 'master' into AGDNS-1982-fix-before-handler
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Apr 9, 2024
2 parents bb485bd + 0368683 commit bc4e3c6
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 25 deletions.
13 changes: 11 additions & 2 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,16 @@ type Proxy struct {
// quicListen are the listened QUIC connections.
quicListen []*quic.EarlyListener

// quicConns are UDP connections for all listened QUIC connections. Those
// should be closed on shutdown, since *quic.EarlyListener doesn't close it.
// quicConns are UDP connections for all listened QUIC connections. These
// should be closed on shutdown, since *quic.EarlyListener doesn't close
// them.
quicConns []*net.UDPConn

// quicTransports are transports for all listened QUIC connections. These
// should be closed on shutdown, since *quic.EarlyListener doesn't close
// them.
quicTransports []*quic.Transport

// httpsListen are the listened HTTPS connections.
httpsListen []net.Listener

Expand Down Expand Up @@ -422,6 +428,9 @@ func (p *Proxy) Shutdown(_ context.Context) (err error) {
errs = closeAll(errs, p.quicListen...)
p.quicListen = nil

errs = closeAll(errs, p.quicTransports...)
p.quicTransports = nil

errs = closeAll(errs, p.quicConns...)
p.quicConns = nil

Expand Down
1 change: 1 addition & 0 deletions proxy/server_quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (p *Proxy) createQUICListeners() error {
return fmt.Errorf("quic listener: %w", err)
}

p.quicTransports = append(p.quicTransports, transport)
p.quicListen = append(p.quicListen, quicListen)

log.Info("listening quic://%s", quicListen.Addr())
Expand Down
38 changes: 23 additions & 15 deletions upstream/doh_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func TestUpstreamDoH(t *testing.T) {
delayHandshakeH2: tc.delayHandshakeH2,
delayHandshakeH3: tc.delayHandshakeH3,
})
t.Cleanup(srv.Shutdown)

// Create a DNS-over-HTTPS upstream.
address := fmt.Sprintf("https://%s/dns-query", srv.addr)
Expand Down Expand Up @@ -175,7 +174,6 @@ func TestUpstreamDoH_raceReconnect(t *testing.T) {
delayHandshakeH3: tc.delayHandshakeH3,
handler: mux,
})
t.Cleanup(srv.Shutdown)

// Create a DNS-over-HTTPS upstream that will be used for the
// race test.
Expand Down Expand Up @@ -216,7 +214,6 @@ func TestUpstreamDoH_serverRestart(t *testing.T) {
srv := startDoHServer(t, testDoHServerOptions{
http3Enabled: true,
})
t.Cleanup(srv.Shutdown)

addr = netip.MustParseAddrPort(srv.addr)
upsAddr = (&url.URL{
Expand All @@ -239,11 +236,10 @@ func TestUpstreamDoH_serverRestart(t *testing.T) {
testutil.CleanupAndRequireSuccess(t, u.Close)

t.Run("second_try", func(t *testing.T) {
srv := startDoHServer(t, testDoHServerOptions{
_ = startDoHServer(t, testDoHServerOptions{
http3Enabled: true,
port: int(addr.Port()),
})
t.Cleanup(srv.Shutdown)

checkUpstream(t, u, upsAddr)
})
Expand All @@ -253,11 +249,10 @@ func TestUpstreamDoH_serverRestart(t *testing.T) {
_, err := u.Exchange(createTestMessage())
require.Error(t, err)

srv := startDoHServer(t, testDoHServerOptions{
_ = startDoHServer(t, testDoHServerOptions{
http3Enabled: true,
port: int(addr.Port()),
})
t.Cleanup(srv.Shutdown)

checkUpstream(t, u, upsAddr)
})
Expand All @@ -270,7 +265,6 @@ func TestUpstreamDoH_0RTT(t *testing.T) {
srv := startDoHServer(t, testDoHServerOptions{
http3Enabled: true,
})
t.Cleanup(srv.Shutdown)

// Create a DNS-over-HTTPS upstream.
tracer := &quicTracer{}
Expand Down Expand Up @@ -319,11 +313,21 @@ func TestUpstreamDoH_0RTT(t *testing.T) {

// testDoHServerOptions allows customizing testDoHServer behavior.
type testDoHServerOptions struct {
handler http.Handler
// handler is an HTTP handler that should be used by the server. The
// default one is used on nil.
handler http.Handler
// delayHandshakeH2 is a delay that should be added to the handshake of the
// HTTP/2 server.
delayHandshakeH2 time.Duration
// delayHandshakeH3 is a delay that should be added to the handshake of the
// HTTP/3 server.
delayHandshakeH3 time.Duration
port int
http3Enabled bool
// port is the port that the server should listen to. If it's 0, a random
// port is used.
port int
// http3Enabled is a flag that indicates whether the server should start an
// HTTP/3 server.
http3Enabled bool
}

// testDoHServer is an instance of a test DNS-over-HTTPS server.
Expand Down Expand Up @@ -359,9 +363,9 @@ func (s *testDoHServer) Shutdown() {
}
}

// startDoHServer starts a new DNS-over-HTTPS server on a random port and
// returns the instance of this server. Depending on whether http3Enabled is
// set to true or false it will or will not initialize a HTTP/3 server.
// startDoHServer starts a new DNS-over-HTTPS server with specified options. It
// returns a started server instance with addr set. Note that it adds its own
// shutdown to cleanup of t.
func startDoHServer(
t *testing.T,
opts testDoHServerOptions,
Expand Down Expand Up @@ -444,6 +448,7 @@ func startDoHServer(
Allow0RTT: true,
})
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, transport.Close)

// Run the H3 server.
go func() {
Expand All @@ -452,7 +457,7 @@ func startDoHServer(
}()
}

return &testDoHServer{
s = &testDoHServer{
tlsConfig: tlsConfig,
rootCAs: rootCAs,
server: server,
Expand All @@ -461,6 +466,9 @@ func startDoHServer(
// Save the address that the server listens to.
addr: tcpAddr.String(),
}
t.Cleanup(s.Shutdown)

return s
}

// createDoHHandlerFunc creates a simple http.HandlerFunc that reads the
Expand Down
14 changes: 6 additions & 8 deletions upstream/doq_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func TestUpstreamDoQ(t *testing.T) {
tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

srv := startDoQServer(t, tlsConf, 0)
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)

address := fmt.Sprintf("quic://%s", srv.addr)
var lastState tls.ConnectionState
Expand Down Expand Up @@ -88,7 +87,6 @@ func TestUpstreamDoQ_serverRestart(t *testing.T) {

t.Run("first_try", func(t *testing.T) {
srv := startDoQServer(t, tlsConf, 0)
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)

addr = netip.MustParseAddrPort(srv.addr)
upsStr = (&url.URL{
Expand All @@ -110,8 +108,7 @@ func TestUpstreamDoQ_serverRestart(t *testing.T) {
testutil.CleanupAndRequireSuccess(t, u.Close)

t.Run("second_try", func(t *testing.T) {
srv := startDoQServer(t, tlsConf, int(addr.Port()))
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)
_ = startDoQServer(t, tlsConf, int(addr.Port()))

checkUpstream(t, u, upsStr)
})
Expand All @@ -121,8 +118,7 @@ func TestUpstreamDoQ_serverRestart(t *testing.T) {
_, err := u.Exchange(createTestMessage())
require.Error(t, err)

srv := startDoQServer(t, tlsConf, int(addr.Port()))
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)
_ = startDoQServer(t, tlsConf, int(addr.Port()))

checkUpstream(t, u, upsStr)
})
Expand All @@ -132,7 +128,6 @@ func TestUpstreamDoQ_0RTT(t *testing.T) {
tlsConf, rootCAs := createServerTLSConfig(t, "127.0.0.1")

srv := startDoQServer(t, tlsConf, 0)
testutil.CleanupAndRequireSuccess(t, srv.Shutdown)

tracer := &quicTracer{}
address := fmt.Sprintf("quic://%s", srv.addr)
Expand Down Expand Up @@ -271,7 +266,8 @@ func (s *testDoQServer) handleQUICStream(stream quic.Stream) (err error) {
return err
}

// startDoQServer starts a test DoQ server.
// startDoQServer starts a test DoQ server. Note that it adds its own shutdown
// to cleanup of t.
func startDoQServer(t *testing.T, tlsConf *tls.Config, port int) (s *testDoQServer) {
tlsConf.NextProtos = []string{NextProtoDQ}

Expand All @@ -297,13 +293,15 @@ func startDoQServer(t *testing.T, tlsConf *tls.Config, port int) (s *testDoQServ
},
)
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, transport.Close)

s = &testDoQServer{
addr: listen.Addr().String(),
listener: listen,
}

go s.Serve()
testutil.CleanupAndRequireSuccess(t, s.Shutdown)

return s
}
Expand Down

0 comments on commit bc4e3c6

Please sign in to comment.