Skip to content

Commit

Permalink
Merge pull request #9 from zTrix/support_half_shutdown
Browse files Browse the repository at this point in the history
Support half open connection to ensure fd release as soon as tcp closed
  • Loading branch information
ihciah authored Feb 25, 2020
2 parents 53e8cc6 + 65d562d commit 1402df6
Show file tree
Hide file tree
Showing 7 changed files with 134 additions and 58 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@
*.out

# idea
.idea
.idea
/bin
13 changes: 9 additions & 4 deletions block/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ package block

import (
"encoding/binary"
"go.uber.org/atomic"
"io"

"go.uber.org/atomic"
)

const (
TypeConnect = iota
TypeDisconnect
TypeData

ShutdownRead = iota
ShutdownWrite
ShutdownBoth

HeaderSize = 1 + 4 + 4 + 4
DataSize = 16*1024 - 13
MaxSize = HeaderSize + DataSize
Expand Down Expand Up @@ -96,12 +101,12 @@ func NewDataBlocks(connectID uint32, blockID *atomic.Uint32, data []byte) []Bloc
return blocks
}

func NewDisconnectBlock(connectID uint32, blockID uint32) Block {
func NewDisconnectBlock(connectID uint32, blockID uint32, shutdownType uint8) Block {
return Block{
Type: TypeDisconnect,
ConnectionID: connectID,
BlockID: blockID,
BlockLength: 0,
BlockData: make([]byte, 0),
BlockLength: 1,
BlockData: []byte{shutdownType},
}
}
38 changes: 25 additions & 13 deletions client/client.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package client

import (
"context"
"github.com/ihciah/rabbit-tcp/logger"
"github.com/ihciah/rabbit-tcp/peer"
"github.com/ihciah/rabbit-tcp/tunnel"
"io"
"net"
"sync"
"time"

"github.com/ihciah/rabbit-tcp/connection"
"github.com/ihciah/rabbit-tcp/logger"
"github.com/ihciah/rabbit-tcp/peer"
"github.com/ihciah/rabbit-tcp/tunnel"
)

type Client struct {
Expand All @@ -22,7 +24,7 @@ func NewClient(tunnelNum int, endpoint string, cipher tunnel.Cipher) Client {
}
}

func (c *Client) Dial(address string) net.Conn {
func (c *Client) Dial(address string) connection.HalfOpenConn {
return c.peer.Dial(address)
}

Expand All @@ -40,28 +42,38 @@ func (c *Client) ServeForward(listen, dest string) error {
go func() {
c.logger.Infoln("Accepted a connection.")
connProxy := c.Dial(dest)
biRelay(conn, connProxy, c.logger)
biRelay(conn.(*net.TCPConn), connProxy, c.logger)
}()
}
}

func biRelay(left, right net.Conn, logger *logger.Logger) {
ctx, cancel := context.WithCancel(context.Background())
go relay(left, right, cancel, logger)
go relay(right, left, cancel, logger)
<-ctx.Done()
func biRelay(left, right connection.HalfOpenConn, logger *logger.Logger) {
var wg sync.WaitGroup
wg.Add(1)
go relay(left, right, &wg, logger, "local <- tunnel")
wg.Add(1)
go relay(right, left, &wg, logger, "local -> tunnel")
wg.Wait()
// logger.Errorf("===========> Close client biRelay")
_ = left.Close()
_ = right.Close()
}

func relay(dst, src net.Conn, cancel context.CancelFunc, logger *logger.Logger) {
func relay(dst, src connection.HalfOpenConn, wg *sync.WaitGroup, logger *logger.Logger, label string) {
defer wg.Done()
_, err := io.Copy(dst, src)
if err != nil {
_ = dst.SetDeadline(time.Now())
_ = src.SetDeadline(time.Now())
cancel()
_ = dst.Close()
_ = src.Close()
if err != io.EOF {
logger.Errorf("Error when relay client: %v.\n", err)
}
} else {
// logger.Debugf("!!!!!!!!!!!!!!!! %s : dst close write", label)
dst.CloseWrite()
// logger.Debugf("!!!!!!!!!!!!!!!! %s : src close read", label)
src.CloseRead()
}
}
9 changes: 5 additions & 4 deletions connection/block_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package connection

import (
"context"
"time"

"github.com/ihciah/rabbit-tcp/block"
"github.com/ihciah/rabbit-tcp/logger"
"go.uber.org/atomic"
"time"
)

// 1. Join blocks from chan to connection orderedRecvQueue
Expand Down Expand Up @@ -69,7 +70,7 @@ func (x *blockProcessor) OrderedRelay(connection Connection) {
case <-time.After(PacketWaitTimeoutSec * time.Second):
x.logger.Debugf("Packet wait time exceed of Connection %d.\n", connection.GetConnectionID())
if x.recvBlockID == x.lastRecvBlockID {
x.logger.Debugf("Connection %d is not in waiting status, continue.\n", connection.GetConnectionID())
x.logger.Debugf("recvBlockId == lastRecvBlockID(%d), but Connection %d is not in waiting status, continue.\n", x.recvBlockID, connection.GetConnectionID())
continue
}
x.logger.Warnf("Connection %d is going to be killed due to timeout.\n", connection.GetConnectionID())
Expand All @@ -89,6 +90,6 @@ func (x *blockProcessor) packConnect(address string, connectionID uint32) block.
return block.NewConnectBlock(connectionID, x.sendBlockID.Inc()-1, address)
}

func (x *blockProcessor) packDisconnect(connectionID uint32) block.Block {
return block.NewDisconnectBlock(connectionID, x.sendBlockID.Inc()-1)
func (x *blockProcessor) packDisconnect(connectionID uint32, shutdownType uint8) block.Block {
return block.NewDisconnectBlock(connectionID, x.sendBlockID.Inc()-1, shutdownType)
}
32 changes: 25 additions & 7 deletions connection/connection.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
package connection

import (
"net"

"github.com/ihciah/rabbit-tcp/block"
"github.com/ihciah/rabbit-tcp/logger"
"go.uber.org/atomic"
"net"
)

type Connection interface {
type HalfOpenConn interface {
net.Conn
CloseRead() error
CloseWrite() error
}

type CloseWrite interface {
CloseWrite() error
}

type CloseRead interface {
CloseRead() error
}

type Connection interface {
HalfOpenConn
GetConnectionID() uint32
getOrderedRecvQueue() chan block.Block
getRecvQueue() chan block.Block

RecvBlock(block.Block)

SendConnect(address string)
SendDisconnect()
SendDisconnect(uint8)

OrderedRelay(connection Connection) // Run orderedRelay infinitely
Stop() // Stop all related relay and remove itself from connectionPool
Expand All @@ -33,6 +48,7 @@ type baseConnection struct {
}

func (bc *baseConnection) Stop() {
bc.logger.Debugf("connection stop\n")
bc.blockProcessor.removeFromPool()
}

Expand Down Expand Up @@ -62,11 +78,13 @@ func (bc *baseConnection) SendConnect(address string) {
bc.sendQueue <- blk
}

func (bc *baseConnection) SendDisconnect() {
bc.logger.Debugln("Send disconnect block.")
blk := bc.blockProcessor.packDisconnect(bc.connectionID)
func (bc *baseConnection) SendDisconnect(shutdownType uint8) {
bc.logger.Debugf("Send disconnect block: %v\n", shutdownType)
blk := bc.blockProcessor.packDisconnect(bc.connectionID, shutdownType)
bc.sendQueue <- blk
bc.Stop()
if shutdownType == block.ShutdownBoth {
bc.Stop()
}
}

func (bc *baseConnection) sendData(data []byte) {
Expand Down
50 changes: 40 additions & 10 deletions connection/inbound_connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package connection
import (
"context"
"fmt"
"github.com/ihciah/rabbit-tcp/block"
"github.com/ihciah/rabbit-tcp/logger"
"go.uber.org/atomic"
"io"
"math/rand"
"net"
"syscall"
"time"

"github.com/ihciah/rabbit-tcp/block"
"github.com/ihciah/rabbit-tcp/logger"
"go.uber.org/atomic"
)

type InboundConnection struct {
Expand All @@ -18,6 +20,9 @@ type InboundConnection struct {

writeCtx context.Context
readCtx context.Context

readClosed *atomic.Bool
writeClosed *atomic.Bool
}

func NewInboundConnection(sendQueue chan<- block.Block, ctx context.Context, removeFromPool context.CancelFunc) Connection {
Expand All @@ -32,9 +37,11 @@ func NewInboundConnection(sendQueue chan<- block.Block, ctx context.Context, rem
orderedRecvQueue: make(chan block.Block, OrderedRecvQueueSize),
logger: logger.NewLogger(fmt.Sprintf("[InboundConnection-%d]", connectionID)),
},
dataBuffer: NewByteRingBuffer(block.MaxSize),
readCtx: ctx,
writeCtx: ctx,
dataBuffer: NewByteRingBuffer(block.MaxSize),
readCtx: ctx,
writeCtx: ctx,
readClosed: atomic.NewBool(false),
writeClosed: atomic.NewBool(false),
}
c.logger.Infof("InboundConnection %d created.\n", connectionID)
return &c
Expand All @@ -52,7 +59,7 @@ func (c *InboundConnection) Read(b []byte) (n int, err error) {
}
}

if c.closed.Load() {
if c.closed.Load() || c.readClosed.Load() {
// Connection is closed, should read all data left in channel
for {
select {
Expand Down Expand Up @@ -118,8 +125,17 @@ func (c *InboundConnection) Read(b []byte) (n int, err error) {
func (c *InboundConnection) readBlock(blk *block.Block, readN *int, b []byte) (err error) {
switch blk.Type {
case block.TypeDisconnect:
c.closed.Store(true)
return io.EOF
// TODO: decide shutdown type
if blk.BlockData[0] == block.ShutdownBoth {
c.closed.Store(true)
return io.EOF
} else if blk.BlockData[0] == block.ShutdownWrite {
c.readClosed.Store(true)
return io.EOF
} else if blk.BlockData[0] == block.ShutdownRead {
c.writeClosed.Store(true)
return nil
}
case block.TypeData:
dst := b[*readN:]
if len(dst) < len(blk.BlockData) {
Expand All @@ -137,14 +153,28 @@ func (c *InboundConnection) readBlock(blk *block.Block, readN *int, b []byte) (e
func (c *InboundConnection) Write(b []byte) (n int, err error) {
// TODO: tag all blocks from b using WaitGroup
// TODO: and wait all blocks sent?
if c.writeClosed.Load() || c.closed.Load() {
return 0, syscall.EINVAL
}
c.sendData(b)
return len(b), nil
}

func (c *InboundConnection) Close() error {
if c.closed.CAS(false, true) {
c.SendDisconnect()
c.SendDisconnect(block.ShutdownBoth)
}
c.Stop()
return nil
}

func (c *InboundConnection) CloseRead() error {
c.SendDisconnect(block.ShutdownRead)
return nil
}

func (c *InboundConnection) CloseWrite() error {
c.SendDisconnect(block.ShutdownWrite)
return nil
}

Expand Down
Loading

0 comments on commit 1402df6

Please sign in to comment.