From 8d95b8fd7b23834eeb3241f16e9177b5d4d862ba Mon Sep 17 00:00:00 2001 From: naiqianz Date: Fri, 12 Jul 2024 00:43:22 +0800 Subject: [PATCH 1/2] add test for tls connCheck MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修改单元测试 测试复现tls问题 尝试修复tls.conn问题 --- internal/pool/conn_check.go | 5 +++++ internal/pool/conn_check_test.go | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/internal/pool/conn_check.go b/internal/pool/conn_check.go index 83190d394..d5e435844 100644 --- a/internal/pool/conn_check.go +++ b/internal/pool/conn_check.go @@ -3,6 +3,7 @@ package pool import ( + "crypto/tls" "errors" "io" "net" @@ -16,6 +17,10 @@ func connCheck(conn net.Conn) error { // Reset previous timeout. _ = conn.SetDeadline(time.Time{}) + // Check if tls.Conn. + if _, ok := conn.(*tls.Conn); ok { + conn = conn.(*tls.Conn).NetConn() + } sysConn, ok := conn.(syscall.Conn) if !ok { return nil diff --git a/internal/pool/conn_check_test.go b/internal/pool/conn_check_test.go index 2ade8a0b9..214993339 100644 --- a/internal/pool/conn_check_test.go +++ b/internal/pool/conn_check_test.go @@ -3,6 +3,7 @@ package pool import ( + "crypto/tls" "net" "net/http/httptest" "time" @@ -14,12 +15,17 @@ import ( var _ = Describe("tests conn_check with real conns", func() { var ts *httptest.Server var conn net.Conn + var tlsConn *tls.Conn var err error BeforeEach(func() { ts = httptest.NewServer(nil) conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second) Expect(err).NotTo(HaveOccurred()) + tlsTestServer := httptest.NewUnstartedServer(nil) + tlsTestServer.StartTLS() + tlsConn, err = tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, tlsTestServer.Listener.Addr().Network(), tlsTestServer.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true}) + Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { @@ -33,11 +39,23 @@ var _ = Describe("tests conn_check with real conns", func() { Expect(connCheck(conn)).To(HaveOccurred()) }) + It("good tls conn check", func() { + Expect(connCheck(tlsConn)).NotTo(HaveOccurred()) + + Expect(tlsConn.Close()).NotTo(HaveOccurred()) + Expect(connCheck(tlsConn)).To(HaveOccurred()) + }) + It("bad conn check", func() { Expect(conn.Close()).NotTo(HaveOccurred()) Expect(connCheck(conn)).To(HaveOccurred()) }) + It("bad tls conn check", func() { + Expect(tlsConn.Close()).NotTo(HaveOccurred()) + Expect(connCheck(tlsConn)).To(HaveOccurred()) + }) + It("check conn deadline", func() { Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred()) time.Sleep(time.Millisecond * 10) From 24ae4a8d56ecfd1effd34c76c453cda0acade590 Mon Sep 17 00:00:00 2001 From: naiqianz Date: Fri, 12 Jul 2024 10:58:35 +0800 Subject: [PATCH 2/2] optimize tls assert --- internal/pool/conn_check.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/pool/conn_check.go b/internal/pool/conn_check.go index d5e435844..07c261c2b 100644 --- a/internal/pool/conn_check.go +++ b/internal/pool/conn_check.go @@ -18,8 +18,8 @@ func connCheck(conn net.Conn) error { _ = conn.SetDeadline(time.Time{}) // Check if tls.Conn. - if _, ok := conn.(*tls.Conn); ok { - conn = conn.(*tls.Conn).NetConn() + if c, ok := conn.(*tls.Conn); ok { + conn = c.NetConn() } sysConn, ok := conn.(syscall.Conn) if !ok {