diff --git a/internal/pool/conn_check.go b/internal/pool/conn_check.go index 83190d394..07c261c2b 100644 --- a/internal/pool/conn_check.go +++ b/internal/pool/conn_check.go @@ -3,6 +3,7 @@ package pool import ( + "crypto/tls" "errors" "io" "net" @@ -16,6 +17,10 @@ func connCheck(conn net.Conn) error { // Reset previous timeout. _ = conn.SetDeadline(time.Time{}) + // Check if tls.Conn. + if c, ok := conn.(*tls.Conn); ok { + conn = c.NetConn() + } sysConn, ok := conn.(syscall.Conn) if !ok { return nil diff --git a/internal/pool/conn_check_test.go b/internal/pool/conn_check_test.go index 2ade8a0b9..214993339 100644 --- a/internal/pool/conn_check_test.go +++ b/internal/pool/conn_check_test.go @@ -3,6 +3,7 @@ package pool import ( + "crypto/tls" "net" "net/http/httptest" "time" @@ -14,12 +15,17 @@ import ( var _ = Describe("tests conn_check with real conns", func() { var ts *httptest.Server var conn net.Conn + var tlsConn *tls.Conn var err error BeforeEach(func() { ts = httptest.NewServer(nil) conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second) Expect(err).NotTo(HaveOccurred()) + tlsTestServer := httptest.NewUnstartedServer(nil) + tlsTestServer.StartTLS() + tlsConn, err = tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, tlsTestServer.Listener.Addr().Network(), tlsTestServer.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true}) + Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { @@ -33,11 +39,23 @@ var _ = Describe("tests conn_check with real conns", func() { Expect(connCheck(conn)).To(HaveOccurred()) }) + It("good tls conn check", func() { + Expect(connCheck(tlsConn)).NotTo(HaveOccurred()) + + Expect(tlsConn.Close()).NotTo(HaveOccurred()) + Expect(connCheck(tlsConn)).To(HaveOccurred()) + }) + It("bad conn check", func() { Expect(conn.Close()).NotTo(HaveOccurred()) Expect(connCheck(conn)).To(HaveOccurred()) }) + It("bad tls conn check", func() { + Expect(tlsConn.Close()).NotTo(HaveOccurred()) + Expect(connCheck(tlsConn)).To(HaveOccurred()) + }) + It("check conn deadline", func() { Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred()) time.Sleep(time.Millisecond * 10)