Fix tunnel pool manager bug; Update peerID exchange method.
ihciah committed Oct 16, 2019
1 parent ff765bd commit 9efaa82
Showing 4 changed files with 78 additions and 33 deletions.
9 changes: 5 additions & 4 deletions tunnel_pool/const.go
@@ -1,8 +1,9 @@
 package tunnel_pool
 
 const (
-    ErrorWaitSec        = 3  // If a tunnel cannot be dialed, will wait for this period and retry infinitely
-    EmptyPoolDestroySec = 60 // The pool will be destroyed(server side) if no tunnel dialed in
-    SendQueueSize       = 48 // SendQueue channel cap
-    RecvQueueSize       = 48 // RecvQueue channel cap
+    ErrorWaitSec          = 3  // If a tunnel cannot be dialed, will wait for this period and retry infinitely
+    TunnelBlockTimeoutSec = 8  // If a tunnel cannot send a block within the limit, will treat it as a dead tunnel
+    EmptyPoolDestroySec   = 60 // The pool will be destroyed(server side) if no tunnel dialed in
+    SendQueueSize         = 48 // SendQueue channel cap
+    RecvQueueSize         = 48 // RecvQueue channel cap
 )
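
The new TunnelBlockTimeoutSec constant feeds the write deadline that packThenSend sets in tunnel.go further down. A minimal sketch of that pattern, assuming a plain net.Conn (writeWithDeadline is a hypothetical helper, not part of this repository):

package sketch

import (
    "net"
    "time"
)

// Mirrors the constant introduced above.
const TunnelBlockTimeoutSec = 8

// writeWithDeadline bounds a single write with a deadline so a stalled peer
// surfaces as a timeout error instead of blocking the relay forever, then
// clears the deadline so later writes are unaffected.
func writeWithDeadline(conn net.Conn, p []byte) error {
    if err := conn.SetWriteDeadline(time.Now().Add(TunnelBlockTimeoutSec * time.Second)); err != nil {
        return err
    }
    if _, err := conn.Write(p); err != nil {
        return err // a dead tunnel typically shows up here as a timeout error
    }
    return conn.SetWriteDeadline(time.Time{})
}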
35 changes: 19 additions & 16 deletions tunnel_pool/manager.go
@@ -4,6 +4,7 @@ import (
     "context"
     "github.com/ihciah/rabbit-tcp/logger"
     "github.com/ihciah/rabbit-tcp/tunnel"
+    "go.uber.org/atomic"
     "net"
     "sync"
     "time"
@@ -15,12 +16,12 @@ type Manager interface {
 }
 
 type ClientManager struct {
-    notifyLock sync.Mutex // Only one notify can run in the same time
-    tunnelNum  int
-    endpoint   string
-    peerID     uint32
-    cipher     tunnel.Cipher
-    logger     *logger.Logger
+    decreaseNotifyLock sync.Mutex // Only one decrease notify can run at the same time
+    tunnelNum          int
+    endpoint           string
+    peerID             uint32
+    cipher             tunnel.Cipher
+    logger             *logger.Logger
 }
 
 func NewClientManager(tunnelNum int, endpoint string, peerID uint32, cipher tunnel.Cipher) ClientManager {
@@ -35,11 +36,18 @@ func NewClientManager(tunnelNum int, endpoint string, peerID uint32, cipher tunnel.Cipher) ClientManager {
 
 // Keep tunnelPool size above tunnelNum
 func (cm *ClientManager) DecreaseNotify(pool *TunnelPool) {
-    cm.notifyLock.Lock()
-    defer cm.notifyLock.Unlock()
+    cm.decreaseNotifyLock.Lock()
+    defer cm.decreaseNotifyLock.Unlock()
     tunnelCount := len(pool.tunnelMapping)
 
     for tunnelToCreate := cm.tunnelNum - tunnelCount; tunnelToCreate > 0; {
+        select {
+        case <-pool.ctx.Done():
+            // Have to return if pool cancel is called.
+            return
+        default:
+        }
+
         cm.logger.Infof("Need %d new tunnels now.\n", tunnelToCreate)
         conn, err := net.Dial("tcp", cm.endpoint)
         if err != nil {
@@ -65,7 +73,7 @@ type ServerManager struct {
     notifyLock          sync.Mutex // Only one notify can run in the same time
     removePeerFunc      context.CancelFunc
     cancelCountDownFunc context.CancelFunc
-    triggered           bool
+    triggered           atomic.Bool
     logger              *logger.Logger
 }
 
@@ -78,28 +86,23 @@ func NewServerManager(removePeerFunc context.CancelFunc) ServerManager {
 
 // If tunnelPool size is zero for more than EmptyPoolDestroySec, delete it
 func (sm *ServerManager) Notify(pool *TunnelPool) {
-    sm.notifyLock.Lock()
-    defer sm.notifyLock.Unlock()
     tunnelCount := len(pool.tunnelMapping)
 
-    if tunnelCount == 0 && !sm.triggered {
+    if tunnelCount == 0 && sm.triggered.CAS(false, true) {
         var destroyAfterCtx context.Context
         destroyAfterCtx, sm.cancelCountDownFunc = context.WithCancel(context.Background())
         go func(*ServerManager) {
             select {
             case <-destroyAfterCtx.Done():
-                sm.triggered = false
                 sm.logger.Debugln("ServerManager notify canceled.")
                 return
             case <-time.After(EmptyPoolDestroySec * time.Second):
                 sm.logger.Infoln("ServerManager will be destroyed.")
                 sm.removePeerFunc()
                 return
             }
         }(sm)
     }
 
-    if tunnelCount != 0 && sm.triggered {
+    if tunnelCount != 0 && sm.triggered.CAS(true, false) {
         sm.cancelCountDownFunc()
     }
 }
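
Two behaviors change in the managers above: ClientManager.DecreaseNotify now checks pool.ctx.Done() on every iteration, so the redial loop stops as soon as the pool is cancelled, and ServerManager swaps its plain triggered bool for go.uber.org/atomic's Bool, so the empty-pool destroy countdown is started and cancelled via compare-and-swap. A stripped-down sketch of that CAS-guarded countdown, under assumed names (countdown, notify, destroy) that are not the repository's types:

package sketch

import (
    "context"
    "log"
    "time"

    "go.uber.org/atomic"
)

// countdown illustrates the pattern: the first notify that sees an empty pool
// flips triggered false->true and starts the destroy timer; the first notify
// that sees a non-empty pool flips it back and cancels the timer.
type countdown struct {
    triggered atomic.Bool
    cancel    context.CancelFunc
}

func (c *countdown) notify(empty bool, destroy func()) {
    if empty && c.triggered.CAS(false, true) {
        var ctx context.Context
        ctx, c.cancel = context.WithCancel(context.Background())
        go func() {
            select {
            case <-ctx.Done():
                log.Println("countdown canceled") // a tunnel arrived in time
            case <-time.After(60 * time.Second): // EmptyPoolDestroySec
                destroy()
            }
        }()
    }
    if !empty && c.triggered.CAS(true, false) {
        c.cancel()
    }
}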
4 changes: 2 additions & 2 deletions tunnel_pool/pool.go
@@ -45,7 +45,7 @@ func (tp *TunnelPool) AddTunnel(tunnel *Tunnel) {
     defer tp.mutex.Unlock()
 
     tp.tunnelMapping[tunnel.tunnelID] = tunnel
-    go tp.manager.Notify(tp)
+    tp.manager.Notify(tp)
 
     tunnel.ctx, tunnel.cancel = context.WithCancel(tp.ctx)
     go func() {
@@ -65,7 +65,7 @@ func (tp *TunnelPool) RemoveTunnel(tunnel *Tunnel) {
     if tunnel, ok := tp.tunnelMapping[tunnel.tunnelID]; ok {
         delete(tp.tunnelMapping, tunnel.tunnelID)
         tp.manager.Notify(tp)
-        tp.manager.DecreaseNotify(tp)
+        go tp.manager.DecreaseNotify(tp)
     }
 }

63 changes: 52 additions & 11 deletions tunnel_pool/tunnel.go
@@ -4,13 +4,15 @@ import (
     "bytes"
     "context"
     "encoding/binary"
+    "errors"
     "fmt"
     "github.com/ihciah/rabbit-tcp/block"
     "github.com/ihciah/rabbit-tcp/logger"
     "github.com/ihciah/rabbit-tcp/tunnel"
     "io"
     "math/rand"
     "net"
+    "time"
 )
 
 type Tunnel struct {
@@ -25,12 +27,12 @@ type Tunnel struct {
 // Create a new tunnel from a net.Conn and cipher with random tunnelID
 func NewActiveTunnel(conn net.Conn, ciph tunnel.Cipher, peerID uint32) (Tunnel, error) {
     tun := newTunnelWithID(conn, ciph, peerID)
-    return tun, tun.sendPeerID()
+    return tun, tun.activeExchangePeerID()
 }
 
 func NewPassiveTunnel(conn net.Conn, ciph tunnel.Cipher) (Tunnel, error) {
     tun := newTunnelWithID(conn, ciph, 0)
-    return tun, tun.recvPeerID()
+    return tun, tun.passiveExchangePeerID()
 }
 
 // Create a new tunnel from a net.Conn and cipher with given tunnelID
@@ -46,28 +48,63 @@ func newTunnelWithID(conn net.Conn, ciph tunnel.Cipher, peerID uint32) Tunnel {
     return tun
 }
 
-func (tunnel *Tunnel) sendPeerID() error {
+func (tunnel *Tunnel) activeExchangePeerID() (err error) {
+    err = tunnel.sendPeerID(tunnel.peerID)
+    if err != nil {
+        tunnel.logger.Errorf("Cannot exchange peerID(send failed: %v).\n", err)
+        return err
+    }
+    peerID, err := tunnel.recvPeerID()
+    if err != nil {
+        tunnel.logger.Errorf("Cannot exchange peerID(recv failed: %v).\n", err)
+        return err
+    }
+    if tunnel.peerID != peerID {
+        tunnel.logger.Errorf("Cannot exchange peerID(local: %d, remote: %d).\n", tunnel.peerID, peerID)
+        return errors.New("invalid exchanging")
+    }
+    tunnel.logger.Infoln("PeerID exchange successfully.")
+    return
+}
+
+func (tunnel *Tunnel) passiveExchangePeerID() (err error) {
+    peerID, err := tunnel.recvPeerID()
+    if err != nil {
+        tunnel.logger.Errorf("Cannot exchange peerID(recv failed: %v).\n", err)
+        return err
+    }
+    err = tunnel.sendPeerID(peerID)
+    if err != nil {
+        tunnel.logger.Errorf("Cannot exchange peerID(send failed: %v).\n", err)
+        return err
+    }
+    tunnel.peerID = peerID
+    tunnel.logger.Infoln("PeerID exchange successfully.")
+    return
+}
+
+func (tunnel *Tunnel) sendPeerID(peerID uint32) error {
     peerIDBuffer := make([]byte, 4)
-    binary.LittleEndian.PutUint32(peerIDBuffer, tunnel.peerID)
+    binary.LittleEndian.PutUint32(peerIDBuffer, peerID)
     _, err := io.CopyN(tunnel.Conn, bytes.NewReader(peerIDBuffer), 4)
     if err != nil {
         tunnel.logger.Errorf("Peer id sent with error:%v.\n", err)
         return err
     }
-    tunnel.logger.Debugln("Peer id sent.")
+    tunnel.logger.Infoln("Peer id sent.")
     return nil
 }
 
-func (tunnel *Tunnel) recvPeerID() error {
+func (tunnel *Tunnel) recvPeerID() (uint32, error) {
     peerIDBuffer := make([]byte, 4)
     _, err := io.ReadFull(tunnel.Conn, peerIDBuffer)
     if err != nil {
         tunnel.logger.Errorf("Peer id recv with error:%v.\n", err)
-        return err
+        return 0, err
     }
-    tunnel.peerID = binary.LittleEndian.Uint32(peerIDBuffer)
-    tunnel.logger.Debugln("Peer id recv.")
-    return nil
+    peerID := binary.LittleEndian.Uint32(peerIDBuffer)
+    tunnel.logger.Infoln("Peer id recv.")
+    return peerID, nil
 }
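
The peerID handshake is now symmetric: the active side sends its peerID and requires the peer to echo the same value back, while the passive side reads the peerID, echoes it, and adopts it. A rough sketch of how the two constructors pair up (dialActiveTunnel and acceptPassiveTunnel are hypothetical helpers, not code from this commit):

package tunnel_pool

import (
    "net"

    "github.com/ihciah/rabbit-tcp/tunnel"
)

// Client side: dial, announce our peerID, and verify the server echoes it back.
func dialActiveTunnel(endpoint string, ciph tunnel.Cipher, peerID uint32) (Tunnel, error) {
    conn, err := net.Dial("tcp", endpoint)
    if err != nil {
        return Tunnel{}, err
    }
    return NewActiveTunnel(conn, ciph, peerID) // sendPeerID -> recvPeerID -> compare
}

// Server side: accept, learn the peerID from the first 4 bytes, and echo it back.
func acceptPassiveTunnel(ln net.Listener, ciph tunnel.Cipher) (Tunnel, error) {
    conn, err := ln.Accept()
    if err != nil {
        return Tunnel{}, err
    }
    return NewPassiveTunnel(conn, ciph) // recvPeerID -> sendPeerID -> adopt peerID
}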

// Read block from send channel, pack it and send
@@ -103,6 +140,8 @@ func (tunnel *Tunnel) OutboundRelay(normalQueue, retryQueue chan block.Block) {
 func (tunnel *Tunnel) packThenSend(blk block.Block, retryQueue chan block.Block) {
     dataToSend := blk.Pack()
     reader := bytes.NewReader(dataToSend)
+
+    tunnel.Conn.SetWriteDeadline(time.Now().Add(TunnelBlockTimeoutSec * time.Second))
     n, err := io.Copy(tunnel.Conn, reader)
     if err != nil || n != int64(len(dataToSend)) {
         tunnel.logger.Warnf("Error when send bytes to tunnel: (n: %d, error: %v).\n", n, err)
@@ -113,6 +152,7 @@ func (tunnel *Tunnel) packThenSend(blk block.Block, retryQueue chan block.Block) {
         }()
         // Use new goroutine to avoid channel blocked
     } else {
+        tunnel.Conn.SetWriteDeadline(time.Time{})
         tunnel.logger.Debugf("Copied data to tunnel successfully(n: %d).\n", n)
     }
 }
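
With the deadline in place, a block that cannot be written within TunnelBlockTimeoutSec fails the io.Copy and is pushed onto retryQueue for another tunnel, while a successful send clears the deadline (SetWriteDeadline with the zero time means no deadline). If the timeout case ever needs to be singled out, the standard-library check looks roughly like this (isDeadTunnel is an illustrative helper, not code from this commit):

package sketch

import (
    "errors"
    "net"
)

// isDeadTunnel reports whether a write error looks like the deadline firing,
// i.e. the peer stopped draining data for longer than TunnelBlockTimeoutSec.
func isDeadTunnel(err error) bool {
    var netErr net.Error
    return errors.As(err, &netErr) && netErr.Timeout()
}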
@@ -125,8 +165,9 @@ func (tunnel *Tunnel) InboundRelay(output chan<- block.Block) {
         case <-tunnel.ctx.Done():
             // Should read all before leave, or packet will be lost
             for {
+                // Will never be blocked because the tunnel is closed
                 blk, err := block.NewBlockFromReader(tunnel.Conn)
-                if err != nil {
+                if err == nil {
                     tunnel.logger.Debugf("Block received from tunnel(type: %d) successfully after close.\n", blk.Type)
                     output <- *blk
                 } else {
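
The final hunk fixes an inverted error check in the post-close drain loop: blocks that can still be read after the tunnel context is cancelled should be forwarded (err == nil), and the loop should stop at the first read error. The hunk is truncated here; a sketch of the corrected behavior, with the loop termination assumed from the visible lines (drainAfterClose is an illustrative helper, and block.NewBlockFromReader is assumed to accept an io.Reader):

package sketch

import (
    "io"

    "github.com/ihciah/rabbit-tcp/block"
)

// drainAfterClose forwards every block that still parses from the connection
// and stops at the first read error, so nothing buffered is lost on shutdown.
func drainAfterClose(r io.Reader, output chan<- block.Block) {
    for {
        blk, err := block.NewBlockFromReader(r)
        if err != nil {
            return // connection closed or corrupted: nothing left to drain
        }
        output <- *blk
    }
}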
