From 7565fdc50a94599f6163941c23829b8a7f724ea2 Mon Sep 17 00:00:00 2001
From: monkey92t <golang@88.com>
Date: Thu, 18 Jul 2024 23:06:57 +0800
Subject: [PATCH 1/2] feat: reduce the type assertion of CheckConn

Signed-off-by: monkey92t <golang@88.com>
---
 internal/pool/conn.go             | 22 ++++++++++++++++++++++
 internal/pool/conn_check.go       | 16 +---------------
 internal/pool/conn_check_dummy.go |  4 ++--
 internal/pool/conn_check_test.go  | 23 ++++++++++++++++-------
 internal/pool/pool.go             | 10 ++++++++--
 5 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/internal/pool/conn.go b/internal/pool/conn.go
index 7f45bc0bb..cc86ca251 100644
--- a/internal/pool/conn.go
+++ b/internal/pool/conn.go
@@ -3,8 +3,10 @@ package pool
 import (
 	"bufio"
 	"context"
+	"crypto/tls"
 	"net"
 	"sync/atomic"
+	"syscall"
 	"time"
 
 	"github.com/redis/go-redis/v9/internal/proto"
@@ -16,6 +18,9 @@ type Conn struct {
 	usedAt  int64 // atomic
 	netConn net.Conn
 
+	// for checking the health status of the connection, it may be nil.
+	sysConn syscall.Conn
+
 	rd *proto.Reader
 	bw *bufio.Writer
 	wr *proto.Writer
@@ -34,6 +39,7 @@ func NewConn(netConn net.Conn) *Conn {
 	cn.bw = bufio.NewWriter(netConn)
 	cn.wr = proto.NewWriter(cn.bw)
 	cn.SetUsedAt(time.Now())
+	cn.setRawConn()
 	return cn
 }
 
@@ -50,6 +56,22 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
 	cn.netConn = netConn
 	cn.rd.Reset(netConn)
 	cn.bw.Reset(netConn)
+	cn.setRawConn()
+}
+
+func (cn *Conn) setRawConn() {
+	cn.sysConn = nil
+	conn := cn.netConn
+	if conn == nil {
+		return
+	}
+	if tlsConn, ok := conn.(*tls.Conn); ok {
+		conn = tlsConn.NetConn()
+	}
+
+	if sysConn, ok := conn.(syscall.Conn); ok {
+		cn.sysConn = sysConn
+	}
 }
 
 func (cn *Conn) Write(b []byte) (int, error) {
diff --git a/internal/pool/conn_check.go b/internal/pool/conn_check.go
index 07c261c2b..f28833850 100644
--- a/internal/pool/conn_check.go
+++ b/internal/pool/conn_check.go
@@ -3,28 +3,14 @@
 package pool
 
 import (
-	"crypto/tls"
 	"errors"
 	"io"
-	"net"
 	"syscall"
-	"time"
 )
 
 var errUnexpectedRead = errors.New("unexpected read from socket")
 
-func connCheck(conn net.Conn) error {
-	// Reset previous timeout.
-	_ = conn.SetDeadline(time.Time{})
-
-	// Check if tls.Conn.
-	if c, ok := conn.(*tls.Conn); ok {
-		conn = c.NetConn()
-	}
-	sysConn, ok := conn.(syscall.Conn)
-	if !ok {
-		return nil
-	}
+func connCheck(sysConn syscall.Conn) error {
 	rawConn, err := sysConn.SyscallConn()
 	if err != nil {
 		return err
diff --git a/internal/pool/conn_check_dummy.go b/internal/pool/conn_check_dummy.go
index 295da1268..2d270cf56 100644
--- a/internal/pool/conn_check_dummy.go
+++ b/internal/pool/conn_check_dummy.go
@@ -2,8 +2,8 @@
 
 package pool
 
-import "net"
+import "syscall"
 
-func connCheck(conn net.Conn) error {
+func connCheck(_ syscall.Conn) error {
 	return nil
 }
diff --git a/internal/pool/conn_check_test.go b/internal/pool/conn_check_test.go
index 214993339..d19969adf 100644
--- a/internal/pool/conn_check_test.go
+++ b/internal/pool/conn_check_test.go
@@ -6,6 +6,7 @@ import (
 	"crypto/tls"
 	"net"
 	"net/http/httptest"
+	"syscall"
 	"time"
 
 	. "github.com/bsm/ginkgo/v2"
@@ -16,16 +17,20 @@ var _ = Describe("tests conn_check with real conns", func() {
 	var ts *httptest.Server
 	var conn net.Conn
 	var tlsConn *tls.Conn
+	var sysConn syscall.Conn
+	var tlsSysConn syscall.Conn
 	var err error
 
 	BeforeEach(func() {
 		ts = httptest.NewServer(nil)
 		conn, err = net.DialTimeout(ts.Listener.Addr().Network(), ts.Listener.Addr().String(), time.Second)
 		Expect(err).NotTo(HaveOccurred())
+		sysConn = conn.(syscall.Conn)
 		tlsTestServer := httptest.NewUnstartedServer(nil)
 		tlsTestServer.StartTLS()
 		tlsConn, err = tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, tlsTestServer.Listener.Addr().Network(), tlsTestServer.Listener.Addr().String(), &tls.Config{InsecureSkipVerify: true})
 		Expect(err).NotTo(HaveOccurred())
+		tlsSysConn = tlsConn.NetConn().(syscall.Conn)
 	})
 
 	AfterEach(func() {
@@ -33,33 +38,37 @@ var _ = Describe("tests conn_check with real conns", func() {
 	})
 
 	It("good conn check", func() {
-		Expect(connCheck(conn)).NotTo(HaveOccurred())
+		Expect(connCheck(sysConn)).NotTo(HaveOccurred())
 
 		Expect(conn.Close()).NotTo(HaveOccurred())
-		Expect(connCheck(conn)).To(HaveOccurred())
+		Expect(connCheck(sysConn)).To(HaveOccurred())
 	})
 
 	It("good tls conn check", func() {
-		Expect(connCheck(tlsConn)).NotTo(HaveOccurred())
+		Expect(connCheck(tlsSysConn)).NotTo(HaveOccurred())
 
 		Expect(tlsConn.Close()).NotTo(HaveOccurred())
-		Expect(connCheck(tlsConn)).To(HaveOccurred())
+		Expect(connCheck(tlsSysConn)).To(HaveOccurred())
 	})
 
 	It("bad conn check", func() {
 		Expect(conn.Close()).NotTo(HaveOccurred())
-		Expect(connCheck(conn)).To(HaveOccurred())
+		Expect(connCheck(sysConn)).To(HaveOccurred())
 	})
 
 	It("bad tls conn check", func() {
 		Expect(tlsConn.Close()).NotTo(HaveOccurred())
-		Expect(connCheck(tlsConn)).To(HaveOccurred())
+		Expect(connCheck(tlsSysConn)).To(HaveOccurred())
 	})
 
 	It("check conn deadline", func() {
 		Expect(conn.SetDeadline(time.Now())).NotTo(HaveOccurred())
 		time.Sleep(time.Millisecond * 10)
-		Expect(connCheck(conn)).NotTo(HaveOccurred())
+		Expect(connCheck(sysConn)).To(HaveOccurred())
+
+		Expect(conn.SetDeadline(time.Now().Add(time.Minute))).NotTo(HaveOccurred())
+		time.Sleep(time.Millisecond * 10)
+		Expect(connCheck(sysConn)).NotTo(HaveOccurred())
 		Expect(conn.Close()).NotTo(HaveOccurred())
 	})
 })
diff --git a/internal/pool/pool.go b/internal/pool/pool.go
index 2125f3e13..9b84993cc 100644
--- a/internal/pool/pool.go
+++ b/internal/pool/pool.go
@@ -499,6 +499,8 @@ func (p *ConnPool) Close() error {
 	return firstErr
 }
 
+var zeroTime = time.Time{}
+
 func (p *ConnPool) isHealthyConn(cn *Conn) bool {
 	now := time.Now()
 
@@ -509,8 +511,12 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool {
 		return false
 	}
 
-	if connCheck(cn.netConn) != nil {
-		return false
+	if cn.sysConn != nil {
+		// reset previous timeout.
+		_ = cn.netConn.SetDeadline(zeroTime)
+		if connCheck(cn.sysConn) != nil {
+			return false
+		}
 	}
 
 	cn.SetUsedAt(now)

From 176ecde2eb04b727d2cc50f218ead334693c8e6f Mon Sep 17 00:00:00 2001
From: monkey92t <golang@88.com>
Date: Thu, 18 Jul 2024 23:11:38 +0800
Subject: [PATCH 2/2] fix: correct the function names

Signed-off-by: monkey92t <golang@88.com>
---
 internal/pool/conn.go | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/internal/pool/conn.go b/internal/pool/conn.go
index cc86ca251..d315c7937 100644
--- a/internal/pool/conn.go
+++ b/internal/pool/conn.go
@@ -39,7 +39,7 @@ func NewConn(netConn net.Conn) *Conn {
 	cn.bw = bufio.NewWriter(netConn)
 	cn.wr = proto.NewWriter(cn.bw)
 	cn.SetUsedAt(time.Now())
-	cn.setRawConn()
+	cn.setSysConn()
 	return cn
 }
 
@@ -56,10 +56,10 @@ func (cn *Conn) SetNetConn(netConn net.Conn) {
 	cn.netConn = netConn
 	cn.rd.Reset(netConn)
 	cn.bw.Reset(netConn)
-	cn.setRawConn()
+	cn.setSysConn()
 }
 
-func (cn *Conn) setRawConn() {
+func (cn *Conn) setSysConn() {
 	cn.sysConn = nil
 	conn := cn.netConn
 	if conn == nil {