Skip to content

Commit

Permalink
Fix: CounterConnection with ReadV/WriteV (#720)
Browse files Browse the repository at this point in the history
Co-authored-by: JimhHan <50871214+JimhHan@users.noreply.github.com>
  • Loading branch information
badO1a5A90 and JimhHan authored Sep 20, 2021
1 parent f2cb13a commit 24b637c
Show file tree
Hide file tree
Showing 53 changed files with 247 additions and 128 deletions.
14 changes: 8 additions & 6 deletions app/proxyman/inbound/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"sync/atomic"
"time"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
Expand Down Expand Up @@ -54,7 +56,7 @@ func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyM
return s.SocketSettings.Tproxy
}

func (w *tcpWorker) callback(conn internet.Connection) {
func (w *tcpWorker) callback(conn stat.Connection) {
ctx, cancel := context.WithCancel(w.ctx)
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)
Expand All @@ -80,7 +82,7 @@ func (w *tcpWorker) callback(conn internet.Connection) {
}

if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &internet.StatCouterConnection{
conn = &stat.CounterConnection{
Connection: conn,
ReadCounter: w.uplinkCounter,
WriteCounter: w.downlinkCounter,
Expand Down Expand Up @@ -117,7 +119,7 @@ func (w *tcpWorker) Proxy() proxy.Inbound {

func (w *tcpWorker) Start() error {
ctx := context.Background()
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
go w.callback(conn)
})
if err != nil {
Expand Down Expand Up @@ -436,13 +438,13 @@ type dsWorker struct {
ctx context.Context
}

func (w *dsWorker) callback(conn internet.Connection) {
func (w *dsWorker) callback(conn stat.Connection) {
ctx, cancel := context.WithCancel(w.ctx)
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)

if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &internet.StatCouterConnection{
conn = &stat.CounterConnection{
Connection: conn,
ReadCounter: w.uplinkCounter,
WriteCounter: w.downlinkCounter,
Expand Down Expand Up @@ -482,7 +484,7 @@ func (w *dsWorker) Port() net.Port {
}
func (w *dsWorker) Start() error {
ctx := context.Background()
hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) {
hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) {
go w.callback(conn)
})
if err != nil {
Expand Down
8 changes: 5 additions & 3 deletions app/proxyman/outbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package outbound
import (
"context"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/mux"
Expand Down Expand Up @@ -158,7 +160,7 @@ func (h *Handler) Address() net.Address {
}

// Dial implements internet.Dialer.
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) {
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
if h.senderSettings != nil {
if h.senderSettings.ProxySettings.HasTag() {
tag := h.senderSettings.ProxySettings.Tag
Expand Down Expand Up @@ -201,9 +203,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
return h.getStatCouterConnection(conn), err
}

func (h *Handler) getStatCouterConnection(conn internet.Connection) internet.Connection {
func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection {
if h.uplinkCounter != nil || h.downlinkCounter != nil {
return &internet.StatCouterConnection{
return &stat.CounterConnection{
Connection: conn,
ReadCounter: h.downlinkCounter,
WriteCounter: h.uplinkCounter,
Expand Down
11 changes: 6 additions & 5 deletions app/proxyman/outbound/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/app/policy"
. "github.com/xtls/xray-core/app/proxyman/outbound"
"github.com/xtls/xray-core/app/stats"
Expand All @@ -12,7 +14,6 @@ import (
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
"github.com/xtls/xray-core/transport/internet"
)

func TestInterfaces(t *testing.T) {
Expand Down Expand Up @@ -44,9 +45,9 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146))
_, ok := conn.(*internet.StatCouterConnection)
_, ok := conn.(*stat.CounterConnection)
if ok {
t.Errorf("Expected conn to not be StatCouterConnection")
t.Errorf("Expected conn to not be CounterConnection")
}
}

Expand All @@ -73,8 +74,8 @@ func TestOutboundWithStatCounter(t *testing.T) {
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146))
_, ok := conn.(*internet.StatCouterConnection)
_, ok := conn.(*stat.CounterConnection)
if !ok {
t.Errorf("Expected conn to be StatCouterConnection")
t.Errorf("Expected conn to be CounterConnection")
}
}
36 changes: 32 additions & 4 deletions common/buf/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"os"
"syscall"
"time"

"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport/internet/stat"
)

// Reader extends io.Reader with MultiBuffer.
Expand All @@ -29,9 +32,17 @@ type Writer interface {
}

// WriteAllBytes ensures all bytes are written into the given writer.
func WriteAllBytes(writer io.Writer, payload []byte) error {
func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
wc := 0
defer func() {
if c != nil {
c.Add(int64(wc))
}
}()

for len(payload) > 0 {
n, err := writer.Write(payload)
wc += n
if err != nil {
return err
}
Expand Down Expand Up @@ -60,12 +71,18 @@ func NewReader(reader io.Reader) Reader {

_, isFile := reader.(*os.File)
if !isFile && useReadv {
var counter stats.Counter

if statConn, ok := reader.(*stat.CounterConnection); ok {
reader = statConn.Connection
counter = statConn.ReadCounter
}
if sc, ok := reader.(syscall.Conn); ok {
rawConn, err := sc.SyscallConn()
if err != nil {
newError("failed to get sysconn").Base(err).WriteToLog()
} else {
return NewReadVReader(reader, rawConn)
return NewReadVReader(reader, rawConn, counter)
}
}
}
Expand Down Expand Up @@ -104,13 +121,24 @@ func NewWriter(writer io.Writer) Writer {
return mw
}

if isPacketWriter(writer) {
var iConn = writer
if statConn, ok := writer.(*stat.CounterConnection); ok {
iConn = statConn.Connection
}

if isPacketWriter(iConn) {
return &SequentialWriter{
Writer: writer,
}
}

var counter stats.Counter

if statConn, ok := writer.(*stat.CounterConnection); ok {
counter = statConn.WriteCounter
}
return &BufferToBytesWriter{
Writer: writer,
Writer: iConn,
counter: counter,
}
}
14 changes: 12 additions & 2 deletions common/buf/readv_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"io"
"syscall"

"github.com/xtls/xray-core/features/stats"

"github.com/xtls/xray-core/common/platform"
)

Expand Down Expand Up @@ -53,17 +55,19 @@ type ReadVReader struct {
rawConn syscall.RawConn
mr multiReader
alloc allocStrategy
counter stats.Counter
}

// NewReadVReader creates a new ReadVReader.
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn) *ReadVReader {
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn, counter stats.Counter) *ReadVReader {
return &ReadVReader{
Reader: reader,
rawConn: rawConn,
alloc: allocStrategy{
current: 1,
},
mr: newMultiReader(),
mr: newMultiReader(),
counter: counter,
}
}

Expand Down Expand Up @@ -122,10 +126,16 @@ func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) {
if b.IsFull() {
r.alloc.Adjust(1)
}
if r.counter != nil && b != nil {
r.counter.Add(int64(b.Len()))
}
return MultiBuffer{b}, err
}

mb, err := r.readMulti()
if r.counter != nil && mb != nil {
r.counter.Add(int64(mb.Len()))
}
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion common/buf/readv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestReadvReader(t *testing.T) {
rawConn, err := conn.(*net.TCPConn).SyscallConn()
common.Must(err)

reader := NewReadVReader(conn, rawConn)
reader := NewReadVReader(conn, rawConn, nil)
var rmb MultiBuffer
for {
mb, err := reader.ReadMultiBuffer()
Expand Down
17 changes: 13 additions & 4 deletions common/buf/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net"
"sync"

"github.com/xtls/xray-core/features/stats"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
)
Expand All @@ -13,7 +15,8 @@ import (
type BufferToBytesWriter struct {
io.Writer

cache [][]byte
counter stats.Counter
cache [][]byte
}

// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
Expand All @@ -26,7 +29,7 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
}

if len(mb) == 1 {
return WriteAllBytes(w.Writer, mb[0].Bytes())
return WriteAllBytes(w.Writer, mb[0].Bytes(), w.counter)
}

if cap(w.cache) < len(mb) {
Expand All @@ -45,9 +48,15 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
}()

nb := net.Buffers(bs)

wc := int64(0)
defer func() {
if w.counter != nil {
w.counter.Add(wc)
}
}()
for size > 0 {
n, err := nb.WriteTo(w.Writer)
wc += n
if err != nil {
return err
}
Expand Down Expand Up @@ -173,7 +182,7 @@ func (w *BufferedWriter) flushInternal() error {
w.buffer = nil

if writer, ok := w.writer.(io.Writer); ok {
err := WriteAllBytes(writer, b.Bytes())
err := WriteAllBytes(writer, b.Bytes(), nil)
b.Release()
return err
}
Expand Down
2 changes: 1 addition & 1 deletion common/crypto/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter {
func (w *CryptionWriter) Write(data []byte) (int, error) {
w.stream.XORKeyStream(data, data)

if err := buf.WriteAllBytes(w.writer, data); err != nil {
if err := buf.WriteAllBytes(w.writer, data, nil); err != nil {
return 0, err
}
return len(data), nil
Expand Down
6 changes: 4 additions & 2 deletions proxy/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"io"
"sync"

"github.com/xtls/xray-core/transport/internet/stat"

"golang.org/x/net/dns/dnsmessage"

"github.com/xtls/xray-core/common"
Expand Down Expand Up @@ -104,7 +106,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx))

conn := &outboundConn{
dialer: func() (internet.Connection, error) {
dialer: func() (stat.Connection, error) {
return d.Dial(ctx, dest)
},
connReady: make(chan struct{}, 1),
Expand Down Expand Up @@ -266,7 +268,7 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,

type outboundConn struct {
access sync.Mutex
dialer func() (internet.Connection, error)
dialer func() (stat.Connection, error)

conn net.Conn
connReady chan struct{}
Expand Down
5 changes: 3 additions & 2 deletions proxy/dokodemo/dokodemo.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync/atomic"
"time"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
Expand All @@ -18,7 +20,6 @@ import (
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet"
)

func init() {
Expand Down Expand Up @@ -76,7 +77,7 @@ type hasHandshakeAddress interface {
}

// Process implements proxy.Inbound.
func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
newError("processing connection from: ", conn.RemoteAddr()).AtDebug().WriteToLog(session.ExportIDToError(ctx))
dest := net.Destination{
Network: network,
Expand Down
Loading

0 comments on commit 24b637c

Please sign in to comment.