Skip to content

Commit

Permalink
core/connpool: pool iff conn confirms to syscall.Conn
Browse files Browse the repository at this point in the history
  • Loading branch information
ignoramous committed Oct 20, 2024
1 parent 5f8f0f2 commit 6b50497
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 61 deletions.
168 changes: 111 additions & 57 deletions intra/core/connpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package core

import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
Expand All @@ -18,10 +19,11 @@ import (
"time"

"github.com/celzero/firestack/intra/log"
"github.com/miekg/dns"
"golang.org/x/sys/unix"
)

const useread = false // always false; here for doc purposes
const useread = false // never used; for documentation only
const poolcapacity = 8 // default capacity
const maxattempts = poolcapacity / 2 // max attempts to retrieve a conn from pool
const Nobody = uintptr(0) // nobody
Expand All @@ -33,7 +35,11 @@ var (
kaidle = int(maxttl / 5 / time.Second) // 8m / 5 => 96s
kainterval = int(maxttl / 10 / time.Second) // 8m / 10 => 48s
)
var errUnexpectedRead error = errors.New("pool: unexpected read")

var (
errUnexpectedRead = errors.New("pool: unexpected read")
errNotSyscallConn = errors.New("pool: not a syscall.Conn")
)

type superpool[T comparable] struct {
quit context.CancelFunc
Expand All @@ -49,8 +55,9 @@ type MultConnPool[T comparable] struct {

func NewMultConnPool[T comparable](ctx context.Context) *MultConnPool[T] {
return &MultConnPool[T]{
ctx: ctx,
m: make(map[T]*superpool[T]),
ctx: ctx,
m: make(map[T]*superpool[T]),
scrubtime: time.Now(),
}
}

Expand Down Expand Up @@ -83,7 +90,7 @@ func (m *MultConnPool[T]) scrub() {
delete(m.m, id)
} else {
nscrubbed++
Go("poo.scrub", super.pool.scrub)
Go("pool.scrub", super.pool.scrub)
}
}

Expand All @@ -107,7 +114,7 @@ func (m *MultConnPool[T]) Get(id T) net.Conn {
return nil
}

func (m *MultConnPool[T]) Put(id T, conn net.Conn) bool {
func (m *MultConnPool[T]) Put(id T, conn net.Conn) (ok bool) {
if IsZero(id) || IsNil(conn) {
return false
}
Expand All @@ -130,24 +137,49 @@ func (m *MultConnPool[T]) Put(id T, conn net.Conn) bool {
return super.pool.Put(conn)
}

type timedconn struct {
type agingconn struct {
c net.Conn
sc syscall.Conn
dob time.Time
str string
}

func newAgingConn(c net.Conn) agingconn {
s := conn2str(c)
if sc, ok := c.(syscall.Conn); ok {
return agingconn{c, sc, time.Now(), s}
} else if dc, ok := c.(*dns.Conn); ok {
if tc, ok := dc.Conn.(*tls.Conn); ok {
if sc, ok := tc.NetConn().(syscall.Conn); ok {
return agingconn{c, sc, time.Now(), s}
}
log.W("pool: dns.Conn not sys.Conn: %T", tc.NetConn())
} else if dc, ok := dc.Conn.(syscall.Conn); ok {
return agingconn{c, dc, time.Now(), s}
}
log.W("pool: dns.Conn not sys.Conn: %T", dc.Conn)
} else if tc, ok := c.(*tls.Conn); ok {
if sc, ok := tc.NetConn().(syscall.Conn); ok {
return agingconn{c, sc, time.Now(), s}
}
log.W("pool: conn not a sys.Conn: %T", c)
}
return agingconn{c, nil, time.Time{}, s}
}

// github.com/redis/go-redis/blob/d9eeed13/internal/pool/pool.go
type ConnPool[T comparable] struct {
ctx context.Context
id T
p chan timedconn // never closed
p chan agingconn // never closed
closed atomic.Bool
}

func NewConnPool[T comparable](ctx context.Context, id T) *ConnPool[T] {
c := &ConnPool[T]{
ctx: ctx,
id: id,
p: make(chan timedconn, poolcapacity),
p: make(chan agingconn, poolcapacity),
}

context.AfterFunc(ctx, c.clean)
Expand All @@ -168,13 +200,13 @@ func (c *ConnPool[T]) Get() (zz net.Conn) {
for i < maxattempts {
i++
select {
case tconn := <-c.p:
case aconn := <-c.p:
// if readable, return conn regardless of its freshness
if readable(tconn.c) {
nokeepalive(tconn.c)
return tconn.c
if aconn.readable() {
aconn.nokeepalive()
return aconn.c
}
CloseConn(tconn.c)
(&aconn).close()
case <-ctx.Done():
return // signal stop
default:
Expand All @@ -192,19 +224,30 @@ func (c *ConnPool[T]) Get() (zz net.Conn) {
return pooled
}

// Put puts conn back in the pool.
// Put takes ownership of the conn regardless of the return value.
func (c *ConnPool[T]) Put(conn net.Conn) (ok bool) {
defer func() {
if !ok {
CloseConn(conn)
}
}()

if c.closed.Load() {
return
}
if c.full() {
return
}

tconn := timedconn{conn, time.Now()}
aconn := newAgingConn(conn)
if !aconn.readable() {
return false
}

select {
case c.p <- tconn:
cleardeadline(conn) // reset any previous timeout
keepalive(conn)
case c.p <- aconn:
aconn.keepalive()
return true
case <-c.ctx.Done(): // stop
return false
Expand All @@ -218,18 +261,18 @@ func (c *ConnPool[T]) empty() bool {
}

func (c *ConnPool[T]) full() bool {
return len(c.p) >= poolcapacity
return len(c.p) > poolcapacity
}

func (c *ConnPool[T]) clean() {
// defer close(c.p)
// todo: defer close(c.p)

ok := c.closed.CompareAndSwap(false, true)
log.I("pool: %v closed? %t", c.id, ok)
for {
select {
case tconn := <-c.p:
CloseConn(tconn.c)
case aconn := <-c.p:
(&aconn).close()
default:
return
}
Expand All @@ -241,67 +284,89 @@ func (c *ConnPool[T]) scrub() {
return
}

staged := make([]timedconn, 0)
staged := make([]agingconn, 0)
defer func() {
for _, tconn := range staged {
for _, aconn := range staged {
kept := false
select {
case <-c.ctx.Done(): // closed
default:
select {
case c.p <- tconn: // put it back in
case c.p <- aconn: // put it back in
kept = true
case <-c.ctx.Done(): // closed
default: // pool full
}
}
if !kept {
CloseConn(tconn.c)
(&aconn).close()
}
}
}()

for {
select {
case tconn := <-c.p:
if fresh(tconn.dob) && readable(tconn.c) {
staged = append(staged, tconn)
case aconn := <-c.p:
if aconn.ok() {
staged = append(staged, aconn)
} else {
CloseConn(tconn.c)
(&aconn).close()
} // next
case <-c.ctx.Done():
case <-c.ctx.Done(): // closed
return
default:
default: // empty
return
}
}
}

func fresh(t time.Time) bool {
return time.Since(t) < maxttl
func (a agingconn) ok() bool {
return a.fresh() &&
a.readable()
}

func (a agingconn) fresh() bool {
return a.dob != (time.Time{}) &&
time.Since(a.dob) < maxttl
}

func (a *agingconn) close() {
a.dob = time.Time{}
CloseConn(a.c)
}

// github.com/golang/go/issues/15735
func readable(c net.Conn) bool {
var err error
id := conn2str(c)
// must use syscall.Conn: github.com/golang/go/issues/65143
switch x := c.(type) {
case syscall.Conn:
err = canread(x)
default:
}
logev(err)("pool: %s readable? %t; err? %v", id, err == nil, err)
func (a agingconn) readable() bool {
err := a.canread()

logev(err)("pool: %s sysconn? %T readable? %t; err? %v",
a.str, a.c, err == nil, err)
return err == nil
}

func (a agingconn) keepalive() bool {
cleardeadline(a.c) // reset any previous timeout
return SetKeepAliveConfigSockOpt(a.c, kaidle, kainterval)
}

func (a agingconn) nokeepalive() bool {
if tc, ok := a.c.(*net.TCPConn); ok {
return tc.SetKeepAlive(false) == nil
}
return false
}

// github.com/go-sql-driver/mysql/blob/f20b28636/conncheck.go
// github.com/redis/go-redis/blob/cc9bcb0c0/internal/pool/conn_check.go
func canread(sc syscall.Conn) error {
func (a agingconn) canread() error {
if a.sc == nil {
return errNotSyscallConn
}

var checkErr error
var ctlErr error

raw, err := sc.SyscallConn()
raw, err := a.sc.SyscallConn()
if err != nil {
return fmt.Errorf("pool: sysconn: %w", err)
}
Expand Down Expand Up @@ -344,17 +409,6 @@ func canread(sc syscall.Conn) error {
return errors.Join(ctlErr, checkErr) // may return nil
}

func keepalive(c net.Conn) bool {
return SetKeepAliveConfigSockOpt(c, kaidle, kainterval)
}

func nokeepalive(c net.Conn) bool {
if tc, ok := c.(*net.TCPConn); ok {
return tc.SetKeepAlive(false) == nil
}
return false
}

func logev(err error) log.LogFn {
return logevif(err != nil)
}
Expand Down
7 changes: 7 additions & 0 deletions intra/core/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package core
import (
"io"
"net"
"syscall"
"time"
)

Expand Down Expand Up @@ -67,9 +68,15 @@ type UDPConn interface {
// DuplexConn represents a bidirectional stream socket.
type DuplexConn interface {
TCPConn
PoolableConn
io.ReaderFrom
}

// so it can be pooled by ConnPool.
type PoolableConn interface {
syscall.Conn
}

type ICMPConn interface {
net.PacketConn
}
Expand Down
9 changes: 9 additions & 0 deletions intra/dialers/direct_split.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"io"
"net"
"sync/atomic"
"syscall"
"time"

"github.com/celzero/firestack/intra/core"
Expand Down Expand Up @@ -143,3 +144,11 @@ func (s *splitter) CloseRead() error { core.CloseTCPRead(s.conn); return nil }

// CloseWrite implements DuplexConn.
func (s *splitter) CloseWrite() error { core.CloseTCPWrite(s.conn); return nil }

// SyscallConn implements syscall.Conn.
func (s *splitter) SyscallConn() (syscall.RawConn, error) {
if c := s.conn; c != nil {
return c.SyscallConn()
}
return nil, syscall.EINVAL
}
Loading

0 comments on commit 6b50497

Please sign in to comment.