Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: CounterConnection with ReadV/WriteV #720

Merged
merged 4 commits into from
Sep 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions app/proxyman/inbound/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"sync/atomic"
"time"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
Expand Down Expand Up @@ -54,7 +56,7 @@ func getTProxyType(s *internet.MemoryStreamConfig) internet.SocketConfig_TProxyM
return s.SocketSettings.Tproxy
}

func (w *tcpWorker) callback(conn internet.Connection) {
func (w *tcpWorker) callback(conn stat.Connection) {
ctx, cancel := context.WithCancel(w.ctx)
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)
Expand All @@ -80,7 +82,7 @@ func (w *tcpWorker) callback(conn internet.Connection) {
}

if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &internet.StatCouterConnection{
conn = &stat.CounterConnection{
Connection: conn,
ReadCounter: w.uplinkCounter,
WriteCounter: w.downlinkCounter,
Expand Down Expand Up @@ -117,7 +119,7 @@ func (w *tcpWorker) Proxy() proxy.Inbound {

func (w *tcpWorker) Start() error {
ctx := context.Background()
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn internet.Connection) {
hub, err := internet.ListenTCP(ctx, w.address, w.port, w.stream, func(conn stat.Connection) {
go w.callback(conn)
})
if err != nil {
Expand Down Expand Up @@ -436,13 +438,13 @@ type dsWorker struct {
ctx context.Context
}

func (w *dsWorker) callback(conn internet.Connection) {
func (w *dsWorker) callback(conn stat.Connection) {
ctx, cancel := context.WithCancel(w.ctx)
sid := session.NewID()
ctx = session.ContextWithID(ctx, sid)

if w.uplinkCounter != nil || w.downlinkCounter != nil {
conn = &internet.StatCouterConnection{
conn = &stat.CounterConnection{
Connection: conn,
ReadCounter: w.uplinkCounter,
WriteCounter: w.downlinkCounter,
Expand Down Expand Up @@ -482,7 +484,7 @@ func (w *dsWorker) Port() net.Port {
}
func (w *dsWorker) Start() error {
ctx := context.Background()
hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn internet.Connection) {
hub, err := internet.ListenUnix(ctx, w.address, w.stream, func(conn stat.Connection) {
go w.callback(conn)
})
if err != nil {
Expand Down
8 changes: 5 additions & 3 deletions app/proxyman/outbound/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package outbound
import (
"context"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/app/proxyman"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/mux"
Expand Down Expand Up @@ -158,7 +160,7 @@ func (h *Handler) Address() net.Address {
}

// Dial implements internet.Dialer.
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) {
func (h *Handler) Dial(ctx context.Context, dest net.Destination) (stat.Connection, error) {
if h.senderSettings != nil {
if h.senderSettings.ProxySettings.HasTag() {
tag := h.senderSettings.ProxySettings.Tag
Expand Down Expand Up @@ -201,9 +203,9 @@ func (h *Handler) Dial(ctx context.Context, dest net.Destination) (internet.Conn
return h.getStatCouterConnection(conn), err
}

func (h *Handler) getStatCouterConnection(conn internet.Connection) internet.Connection {
func (h *Handler) getStatCouterConnection(conn stat.Connection) stat.Connection {
if h.uplinkCounter != nil || h.downlinkCounter != nil {
return &internet.StatCouterConnection{
return &stat.CounterConnection{
Connection: conn,
ReadCounter: h.downlinkCounter,
WriteCounter: h.uplinkCounter,
Expand Down
11 changes: 6 additions & 5 deletions app/proxyman/outbound/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"testing"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/app/policy"
. "github.com/xtls/xray-core/app/proxyman/outbound"
"github.com/xtls/xray-core/app/stats"
Expand All @@ -12,7 +14,6 @@ import (
core "github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/proxy/freedom"
"github.com/xtls/xray-core/transport/internet"
)

func TestInterfaces(t *testing.T) {
Expand Down Expand Up @@ -44,9 +45,9 @@ func TestOutboundWithoutStatCounter(t *testing.T) {
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146))
_, ok := conn.(*internet.StatCouterConnection)
_, ok := conn.(*stat.CounterConnection)
if ok {
t.Errorf("Expected conn to not be StatCouterConnection")
t.Errorf("Expected conn to not be CounterConnection")
}
}

Expand All @@ -73,8 +74,8 @@ func TestOutboundWithStatCounter(t *testing.T) {
ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
})
conn, _ := h.(*Handler).Dial(ctx, net.TCPDestination(net.DomainAddress("localhost"), 13146))
_, ok := conn.(*internet.StatCouterConnection)
_, ok := conn.(*stat.CounterConnection)
if !ok {
t.Errorf("Expected conn to be StatCouterConnection")
t.Errorf("Expected conn to be CounterConnection")
}
}
36 changes: 32 additions & 4 deletions common/buf/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import (
"os"
"syscall"
"time"

"github.com/xtls/xray-core/features/stats"
"github.com/xtls/xray-core/transport/internet/stat"
)

// Reader extends io.Reader with MultiBuffer.
Expand All @@ -29,9 +32,17 @@ type Writer interface {
}

// WriteAllBytes ensures all bytes are written into the given writer.
func WriteAllBytes(writer io.Writer, payload []byte) error {
func WriteAllBytes(writer io.Writer, payload []byte, c stats.Counter) error {
wc := 0
defer func() {
if c != nil {
c.Add(int64(wc))
}
}()

for len(payload) > 0 {
n, err := writer.Write(payload)
wc += n
if err != nil {
return err
}
Expand Down Expand Up @@ -60,12 +71,18 @@ func NewReader(reader io.Reader) Reader {

_, isFile := reader.(*os.File)
if !isFile && useReadv {
var counter stats.Counter

if statConn, ok := reader.(*stat.CounterConnection); ok {
reader = statConn.Connection
counter = statConn.ReadCounter
}
if sc, ok := reader.(syscall.Conn); ok {
rawConn, err := sc.SyscallConn()
if err != nil {
newError("failed to get sysconn").Base(err).WriteToLog()
} else {
return NewReadVReader(reader, rawConn)
return NewReadVReader(reader, rawConn, counter)
}
}
}
Expand Down Expand Up @@ -104,13 +121,24 @@ func NewWriter(writer io.Writer) Writer {
return mw
}

if isPacketWriter(writer) {
var iConn = writer
if statConn, ok := writer.(*stat.CounterConnection); ok {
iConn = statConn.Connection
}

if isPacketWriter(iConn) {
return &SequentialWriter{
Writer: writer,
}
}

var counter stats.Counter

if statConn, ok := writer.(*stat.CounterConnection); ok {
counter = statConn.WriteCounter
}
return &BufferToBytesWriter{
Writer: writer,
Writer: iConn,
counter: counter,
}
}
14 changes: 12 additions & 2 deletions common/buf/readv_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"io"
"syscall"

"github.com/xtls/xray-core/features/stats"

"github.com/xtls/xray-core/common/platform"
)

Expand Down Expand Up @@ -53,17 +55,19 @@ type ReadVReader struct {
rawConn syscall.RawConn
mr multiReader
alloc allocStrategy
counter stats.Counter
}

// NewReadVReader creates a new ReadVReader.
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn) *ReadVReader {
func NewReadVReader(reader io.Reader, rawConn syscall.RawConn, counter stats.Counter) *ReadVReader {
return &ReadVReader{
Reader: reader,
rawConn: rawConn,
alloc: allocStrategy{
current: 1,
},
mr: newMultiReader(),
mr: newMultiReader(),
counter: counter,
}
}

Expand Down Expand Up @@ -122,10 +126,16 @@ func (r *ReadVReader) ReadMultiBuffer() (MultiBuffer, error) {
if b.IsFull() {
r.alloc.Adjust(1)
}
if r.counter != nil && b != nil {
r.counter.Add(int64(b.Len()))
}
return MultiBuffer{b}, err
}

mb, err := r.readMulti()
if r.counter != nil && mb != nil {
r.counter.Add(int64(mb.Len()))
}
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion common/buf/readv_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestReadvReader(t *testing.T) {
rawConn, err := conn.(*net.TCPConn).SyscallConn()
common.Must(err)

reader := NewReadVReader(conn, rawConn)
reader := NewReadVReader(conn, rawConn, nil)
var rmb MultiBuffer
for {
mb, err := reader.ReadMultiBuffer()
Expand Down
17 changes: 13 additions & 4 deletions common/buf/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"net"
"sync"

"github.com/xtls/xray-core/features/stats"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
)
Expand All @@ -13,7 +15,8 @@ import (
type BufferToBytesWriter struct {
io.Writer

cache [][]byte
counter stats.Counter
cache [][]byte
}

// WriteMultiBuffer implements Writer. This method takes ownership of the given buffer.
Expand All @@ -26,7 +29,7 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
}

if len(mb) == 1 {
return WriteAllBytes(w.Writer, mb[0].Bytes())
return WriteAllBytes(w.Writer, mb[0].Bytes(), w.counter)
}

if cap(w.cache) < len(mb) {
Expand All @@ -45,9 +48,15 @@ func (w *BufferToBytesWriter) WriteMultiBuffer(mb MultiBuffer) error {
}()

nb := net.Buffers(bs)

wc := int64(0)
defer func() {
if w.counter != nil {
w.counter.Add(wc)
}
}()
for size > 0 {
n, err := nb.WriteTo(w.Writer)
wc += n
if err != nil {
return err
}
Expand Down Expand Up @@ -173,7 +182,7 @@ func (w *BufferedWriter) flushInternal() error {
w.buffer = nil

if writer, ok := w.writer.(io.Writer); ok {
err := WriteAllBytes(writer, b.Bytes())
err := WriteAllBytes(writer, b.Bytes(), nil)
b.Release()
return err
}
Expand Down
2 changes: 1 addition & 1 deletion common/crypto/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewCryptionWriter(stream cipher.Stream, writer io.Writer) *CryptionWriter {
func (w *CryptionWriter) Write(data []byte) (int, error) {
w.stream.XORKeyStream(data, data)

if err := buf.WriteAllBytes(w.writer, data); err != nil {
if err := buf.WriteAllBytes(w.writer, data, nil); err != nil {
return 0, err
}
return len(data), nil
Expand Down
6 changes: 4 additions & 2 deletions proxy/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"io"
"sync"

"github.com/xtls/xray-core/transport/internet/stat"

"golang.org/x/net/dns/dnsmessage"

"github.com/xtls/xray-core/common"
Expand Down Expand Up @@ -104,7 +106,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.
newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx))

conn := &outboundConn{
dialer: func() (internet.Connection, error) {
dialer: func() (stat.Connection, error) {
return d.Dial(ctx, dest)
},
connReady: make(chan struct{}, 1),
Expand Down Expand Up @@ -266,7 +268,7 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,

type outboundConn struct {
access sync.Mutex
dialer func() (internet.Connection, error)
dialer func() (stat.Connection, error)

conn net.Conn
connReady chan struct{}
Expand Down
5 changes: 3 additions & 2 deletions proxy/dokodemo/dokodemo.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"sync/atomic"
"time"

"github.com/xtls/xray-core/transport/internet/stat"

"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/log"
Expand All @@ -18,7 +20,6 @@ import (
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/policy"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet"
)

func init() {
Expand Down Expand Up @@ -76,7 +77,7 @@ type hasHandshakeAddress interface {
}

// Process implements proxy.Inbound.
func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
newError("processing connection from: ", conn.RemoteAddr()).AtDebug().WriteToLog(session.ExportIDToError(ctx))
dest := net.Destination{
Network: network,
Expand Down
Loading