diff --git a/tunnel/destination.go b/tunnel/destination.go index caee9192..c82dbf10 100644 --- a/tunnel/destination.go +++ b/tunnel/destination.go @@ -16,6 +16,7 @@ package tunnel import ( "io" + "sync" "google.golang.org/protobuf/proto" @@ -23,10 +24,23 @@ import ( "github.com/seqsense/aws-iot-device-sdk-go/v6/tunnel/msg" ) -func proxyDestination(ws io.ReadWriter, dialer Dialer, eh ErrorHandler) error { +func proxyDestination(ws io.ReadWriter, dialer Dialer, eh ErrorHandler, stat Stat) error { + var muConns sync.Mutex conns := make(map[int32]io.ReadWriteCloser) sz := make([]byte, 2) b := make([]byte, 8192) + + updateStat := func() { + if stat != nil { + muConns.Lock() + n := len(conns) + muConns.Unlock() + stat.Update(func(stat *Statistics) { + stat.NumConn = n + }) + } + } + for { if _, err := io.ReadFull(ws, sz); err != nil { if err == io.EOF { @@ -62,28 +76,47 @@ func proxyDestination(ws io.ReadWriter, dialer Dialer, eh ErrorHandler) error { continue } + muConns.Lock() conns[m.StreamId] = conn - go readProxy(ws, conn, m.StreamId, eh) + muConns.Unlock() + go func() { + readProxy(ws, conn, m.StreamId, eh) + muConns.Lock() + if conn, ok := conns[m.StreamId]; ok { + _ = conn.Close() + delete(conns, m.StreamId) + } + muConns.Unlock() + updateStat() + }() case msg.Message_STREAM_RESET: + muConns.Lock() if conn, ok := conns[m.StreamId]; ok { _ = conn.Close() delete(conns, m.StreamId) } + muConns.Unlock() case msg.Message_SESSION_RESET: + muConns.Lock() for id, c := range conns { _ = c.Close() delete(conns, id) } + muConns.Unlock() return io.EOF case msg.Message_DATA: - if conn, ok := conns[m.StreamId]; ok { + muConns.Lock() + conn, ok := conns[m.StreamId] + muConns.Unlock() + if ok { if _, err := conn.Write(m.Payload); err != nil { eh.HandleError(ioterr.New(err, "writing message")) } } } + updateStat() } } diff --git a/tunnel/proxy.go b/tunnel/proxy.go index 00399d5a..1b75ffd4 100644 --- a/tunnel/proxy.go +++ b/tunnel/proxy.go @@ -57,7 +57,7 @@ func ProxyDestination(dialer Dialer, endpoint, token string, opts ...ProxyOption pingCancel := newPinger(ws, opt.PingPeriod) defer pingCancel() - return proxyDestination(ws, dialer, opt.ErrorHandler) + return proxyDestination(ws, dialer, opt.ErrorHandler, opt.Stat) } // ProxySource proxies TCP connection from local socket to @@ -72,7 +72,7 @@ func ProxySource(listener net.Listener, endpoint, token string, opts ...ProxyOpt pingCancel := newPinger(ws, opt.PingPeriod) defer pingCancel() - return proxySource(ws, listener, opt.ErrorHandler) + return proxySource(ws, listener, opt.ErrorHandler, opt.Stat) } func openProxyConn(endpoint, mode, token string, opts ...ProxyOption) (*websocket.Conn, *ProxyOptions, error) { @@ -140,6 +140,7 @@ type ProxyOptions struct { Scheme string ErrorHandler ErrorHandler PingPeriod time.Duration + Stat Stat } func (o *ProxyOptions) validate() error { @@ -166,3 +167,11 @@ func WithPingPeriod(d time.Duration) ProxyOption { return nil } } + +// WithStat enables statistics. +func WithStat(stat Stat) ProxyOption { + return func(opt *ProxyOptions) error { + opt.Stat = stat + return nil + } +} diff --git a/tunnel/proxy_test.go b/tunnel/proxy_test.go index ee5ac22c..45720b31 100644 --- a/tunnel/proxy_test.go +++ b/tunnel/proxy_test.go @@ -18,7 +18,9 @@ import ( "errors" "io" "net" + "reflect" "sync" + "sync/atomic" "testing" "time" @@ -34,90 +36,107 @@ var ( func TestProxyDestination(t *testing.T) { t.Run("Success", func(t *testing.T) { - tca, tcb := net.Pipe() - ca, cb := net.Pipe() - - var wg sync.WaitGroup - wg.Add(1) - defer wg.Wait() - - go func() { - defer wg.Done() - err := proxyDestination(tca, - func() (io.ReadWriteCloser, error) { return cb, nil }, - nil, - ) + test := func(t *testing.T, stat Stat) { + tca, tcb := net.Pipe() + ca, cb := net.Pipe() + + var wg sync.WaitGroup + wg.Add(1) + defer wg.Wait() + + go func() { + defer wg.Done() + err := proxyDestination(tca, + func() (io.ReadWriteCloser, error) { return cb, nil }, + nil, stat, + ) + if err != nil { + t.Error(err) + } + }() + + payload1 := "the payload 1" + payload2 := "the payload 2" + + // Check source to destination + msgs := []*msg.Message{ + { + Type: msg.Message_STREAM_START, + StreamId: 1, + }, + { + Type: msg.Message_DATA, + StreamId: 1, + Payload: []byte(payload1), + }, + } + for _, m := range msgs { + b, err := proto.Marshal(m) + if err != nil { + t.Fatal(err) + } + l := len(b) + _, err = tcb.Write(append( + []byte{byte(l >> 8), byte(l)}, + b..., + )) + if err != nil { + t.Fatal(err) + } + } + bRecv := make([]byte, 100) + n, err := ca.Read(bRecv) if err != nil { - t.Error(err) + t.Fatal(err) + } + if string(bRecv[:n]) != payload1 { + t.Errorf("payload differs, expected: %s, got: %s", payload1, string(bRecv[:n])) } - }() - - payload1 := "the payload 1" - payload2 := "the payload 2" - // Check source to destination - msgs := []*msg.Message{ - { - Type: msg.Message_STREAM_START, - StreamId: 1, - }, - { - Type: msg.Message_DATA, - StreamId: 1, - Payload: []byte(payload1), - }, - } - for _, m := range msgs { - b, err := proto.Marshal(m) - if err != nil { + // Check destination to source + if _, err := ca.Write([]byte(payload2)); err != nil { t.Fatal(err) } - l := len(b) - _, err = tcb.Write(append( - []byte{byte(l >> 8), byte(l)}, - b..., - )) - if err != nil { + sz := make([]byte, 2) + if _, err := io.ReadFull(tcb, sz); err != nil { t.Fatal(err) } - } - bRecv := make([]byte, 100) - n, err := ca.Read(bRecv) - if err != nil { - t.Fatal(err) - } - if string(bRecv[:n]) != payload1 { - t.Errorf("payload differs, expected: %s, got: %s", payload1, string(bRecv[:n])) - } + bSent := make([]byte, int(sz[0])<<8|int(sz[1])) + if _, err := io.ReadFull(tcb, bSent); err != nil { + t.Fatal(err) + } + m := &msg.Message{} + if err := proto.Unmarshal(bSent, m); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + msgExpected := &msg.Message{ + Type: msg.Message_DATA, + StreamId: 1, + Payload: []byte(payload2), + } + if !proto.Equal(msgExpected, m) { + t.Errorf("message differes, expected: %v, got: %v", msgExpected, m) + } - // Check destination to source - if _, err := ca.Write([]byte(payload2)); err != nil { - t.Fatal(err) - } - sz := make([]byte, 2) - if _, err := io.ReadFull(tcb, sz); err != nil { - t.Fatal(err) - } - bSent := make([]byte, int(sz[0])<<8|int(sz[1])) - if _, err := io.ReadFull(tcb, bSent); err != nil { - t.Fatal(err) - } - m := &msg.Message{} - if err := proto.Unmarshal(bSent, m); err != nil { - t.Fatalf("unmarshal failed: %v", err) - } - msgExpected := &msg.Message{ - Type: msg.Message_DATA, - StreamId: 1, - Payload: []byte(payload2), - } - if !proto.Equal(msgExpected, m) { - t.Errorf("message differes, expected: %v, got: %v", msgExpected, m) + if err := tcb.Close(); err != nil { + t.Fatal(err) + } } + t.Run("WithoutStat", func(t *testing.T) { + test(t, nil) + }) + t.Run("WithStat", func(t *testing.T) { + stat := NewStat() + test(t, stat) - if err := tcb.Close(); err != nil { - t.Fatal(err) - } + s := stat.Statistics() + expected := Statistics{ + NumConn: 1, + } + if !reflect.DeepEqual(expected, s) { + t.Errorf("Expected stat: %+v, got: %+v", expected, s) + } + }) }) t.Run("DialError", func(t *testing.T) { tca, tcb := net.Pipe() @@ -134,6 +153,7 @@ func TestProxyDestination(t *testing.T) { ErrorHandlerFunc(func(err error) { chErr <- err }), + nil, ); err != nil { t.Error(err) } @@ -169,6 +189,9 @@ func TestProxyDestination(t *testing.T) { tca, tcb := net.Pipe() ca, cb := net.Pipe() + cbWithErr := &errorConn{Conn: cb} + cb = cbWithErr + var wg sync.WaitGroup defer wg.Wait() @@ -181,6 +204,7 @@ func TestProxyDestination(t *testing.T) { ErrorHandlerFunc(func(err error) { chErr <- err }), + nil, ); err != nil { t.Error(err) } @@ -197,7 +221,8 @@ func TestProxyDestination(t *testing.T) { if _, err = tcb.Write(append([]byte{byte(l >> 8), byte(l)}, b...)); err != nil { t.Fatal(err) } - ca.Close() + cbWithErr.err.Store(io.ErrClosedPipe) + defer ca.Close() b, err = proto.Marshal(&msg.Message{ Type: msg.Message_DATA, @@ -244,7 +269,7 @@ func TestProxyDestination(t *testing.T) { } }() - err := proxyDestination(ca, nil, nil) + err := proxyDestination(ca, nil, nil, nil) var ie *ioterr.Error if !errors.As(err, &ie) { @@ -265,6 +290,7 @@ func TestProxyDestination(t *testing.T) { go func() { if err := proxyDestination(ca, nil, ErrorHandlerFunc(func(err error) { chErr <- err }), + nil, ); err != nil { t.Error(err) } @@ -294,103 +320,120 @@ func TestProxyDestination(t *testing.T) { func TestProxySource(t *testing.T) { t.Run("Success", func(t *testing.T) { - tca, tcb := net.Pipe() - ca, cb := net.Pipe() + test := func(t *testing.T, stat Stat) { + tca, tcb := net.Pipe() + ca, cb := net.Pipe() + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + var i int + if err := proxySource(tca, + acceptFunc(func() (net.Conn, error) { + if i > 0 { + return nil, errors.New("done") + } + i++ + return cb, nil + }), + nil, stat, + ); err != nil { + t.Error(err) + } + }() - var wg sync.WaitGroup - defer wg.Wait() + payload1 := "the payload 1" + payload2 := "the payload 2" - wg.Add(1) - go func() { - defer wg.Done() - var i int - if err := proxySource(tca, - acceptFunc(func() (net.Conn, error) { - if i > 0 { - return nil, errors.New("done") - } - i++ - return cb, nil - }), - nil, - ); err != nil { - t.Error(err) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := ca.Write([]byte(payload1)); err != nil { + t.Error(err) + } + }() + + // Check source to destination + msgsExpected := []*msg.Message{ + { + Type: msg.Message_STREAM_START, + StreamId: 1, + }, + { + Type: msg.Message_DATA, + StreamId: 1, + Payload: []byte(payload1), + }, } - }() - - payload1 := "the payload 1" - payload2 := "the payload 2" - - wg.Add(1) - go func() { - defer wg.Done() - if _, err := ca.Write([]byte(payload1)); err != nil { - t.Error(err) + for _, me := range msgsExpected { + sz := make([]byte, 2) + if _, err := io.ReadFull(tcb, sz); err != nil { + t.Fatal(err) + } + bSent := make([]byte, int(sz[0])<<8|int(sz[1])) + if _, err := io.ReadFull(tcb, bSent); err != nil { + t.Fatal(err) + } + m := &msg.Message{} + if err := proto.Unmarshal(bSent, m); err != nil { + t.Fatalf("unmarshal failed: %v", err) + } + if !proto.Equal(me, m) { + t.Errorf("message differs, expected: %v, got: %v", me, m) + } } - }() - // Check source to destination - msgsExpected := []*msg.Message{ - { - Type: msg.Message_STREAM_START, - StreamId: 1, - }, - { + // Check destination to source + msg := &msg.Message{ Type: msg.Message_DATA, StreamId: 1, - Payload: []byte(payload1), - }, - } - for _, me := range msgsExpected { - sz := make([]byte, 2) - if _, err := io.ReadFull(tcb, sz); err != nil { + Payload: []byte(payload2), + } + b, err := proto.Marshal(msg) + if err != nil { t.Fatal(err) } - bSent := make([]byte, int(sz[0])<<8|int(sz[1])) - if _, err := io.ReadFull(tcb, bSent); err != nil { + l := len(b) + _, err = tcb.Write(append( + []byte{byte(l >> 8), byte(l)}, + b..., + )) + if err != nil { t.Fatal(err) } - m := &msg.Message{} - if err := proto.Unmarshal(bSent, m); err != nil { - t.Fatalf("unmarshal failed: %v", err) + + bRecv := make([]byte, 100) + n, err := ca.Read(bRecv) + if err != nil { + t.Fatal(err) } - if !proto.Equal(me, m) { - t.Errorf("message differs, expected: %v, got: %v", me, m) + if string(bRecv[:n]) != payload2 { + t.Errorf("payload differs, expected: %s, got: %s", payload2, string(bRecv[:n])) } - } - // Check destination to source - msg := &msg.Message{ - Type: msg.Message_DATA, - StreamId: 1, - Payload: []byte(payload2), - } - b, err := proto.Marshal(msg) - if err != nil { - t.Fatal(err) - } - l := len(b) - _, err = tcb.Write(append( - []byte{byte(l >> 8), byte(l)}, - b..., - )) - if err != nil { - t.Fatal(err) - } - - bRecv := make([]byte, 100) - n, err := ca.Read(bRecv) - if err != nil { - t.Fatal(err) - } - if string(bRecv[:n]) != payload2 { - t.Errorf("payload differs, expected: %s, got: %s", payload2, string(bRecv[:n])) + // Check EOF + if err := tcb.Close(); err != nil { + t.Fatal(err) + } } + t.Run("WithoutStat", func(t *testing.T) { + test(t, nil) + }) + t.Run("WithStat", func(t *testing.T) { + stat := NewStat() + test(t, stat) - // Check EOF - if err := tcb.Close(); err != nil { - t.Fatal(err) - } + s := stat.Statistics() + expected := Statistics{ + NumConn: 1, + } + if !reflect.DeepEqual(expected, s) { + t.Errorf("Expected stat: %+v, got: %+v", expected, s) + } + }) }) t.Run("AcceptError", func(t *testing.T) { tca, tcb := net.Pipe() @@ -407,6 +450,7 @@ func TestProxySource(t *testing.T) { ErrorHandlerFunc(func(err error) { chErr <- err }), + nil, ); err != nil { t.Error(err) } @@ -471,7 +515,7 @@ func TestProxySource(t *testing.T) { wg.Wait() return nil, errConnect }), - nil, + nil, nil, ) var ie *ioterr.Error @@ -500,6 +544,7 @@ func TestProxySource(t *testing.T) { return nil, errConnect }), ErrorHandlerFunc(func(err error) { chErr <- err }), + nil, ); err != nil { t.Error(err) } @@ -556,6 +601,7 @@ func TestProxySource(t *testing.T) { return cb, nil }), ErrorHandlerFunc(func(err error) { chErr <- err }), + nil, ); err != nil { t.Error(err) } @@ -627,3 +673,15 @@ func TestProxyOption_validate(t *testing.T) { } }) } + +type errorConn struct { + net.Conn + err atomic.Value +} + +func (c *errorConn) Write(b []byte) (int, error) { + if err, ok := c.err.Load().(error); err != nil && ok { + return 0, err + } + return c.Conn.Write(b) +} diff --git a/tunnel/source.go b/tunnel/source.go index 323fb527..abce3baf 100644 --- a/tunnel/source.go +++ b/tunnel/source.go @@ -25,9 +25,21 @@ import ( "github.com/seqsense/aws-iot-device-sdk-go/v6/tunnel/msg" ) -func proxySource(ws io.ReadWriter, listener net.Listener, eh ErrorHandler) error { +func proxySource(ws io.ReadWriter, listener net.Listener, eh ErrorHandler, stat Stat) error { var muConns sync.Mutex conns := make(map[int32]io.ReadWriteCloser) + + updateStat := func() { + if stat != nil { + muConns.Lock() + n := len(conns) + muConns.Unlock() + stat.Update(func(stat *Statistics) { + stat.NumConn = n + }) + } + } + go func() { var streamID int32 = 1 for { @@ -55,7 +67,16 @@ func proxySource(ws io.ReadWriter, listener net.Listener, eh ErrorHandler) error continue } - go readProxy(ws, conn, id, eh) + go func() { + readProxy(ws, conn, id, eh) + muConns.Lock() + if conn, ok := conns[id]; ok { + _ = conn.Close() + delete(conns, id) + } + muConns.Unlock() + updateStat() + }() } }() @@ -114,5 +135,6 @@ func proxySource(ws io.ReadWriter, listener net.Listener, eh ErrorHandler) error } } } + updateStat() } } diff --git a/tunnel/stat.go b/tunnel/stat.go new file mode 100644 index 00000000..31f53e3c --- /dev/null +++ b/tunnel/stat.go @@ -0,0 +1,53 @@ +// Copyright 2021 SEQSENSE, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tunnel + +import ( + "sync" +) + +// Stat is an interface to get and update the statistics of the proxy. +// All methods are thread safe. +type Stat interface { + Statistics() Statistics + Update(func(*Statistics)) +} + +// Statistics stores proxy statistics data. +type Statistics struct { + NumConn int +} + +// NewStat creates new Stat. +func NewStat() Stat { + return &stat{} +} + +type stat struct { + stat Statistics + mu sync.RWMutex +} + +func (s *stat) Statistics() Statistics { + s.mu.RLock() + defer s.mu.RUnlock() + return s.stat +} + +func (s *stat) Update(fn func(*Statistics)) { + s.mu.Lock() + fn(&s.stat) + s.mu.Unlock() +} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index b809147d..e392f8d0 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -47,6 +47,9 @@ type Options struct { EndpointHostFunc func(region string) string // TopicFunc is a function returns MQTT topic for the operation. TopicFunc func(operation string) string + + // ProxyOptions stores slice of ProxyOptions for each service. + ProxyOptions map[string][]ProxyOption } // Option is a type of functional options. @@ -68,6 +71,7 @@ func New(ctx context.Context, cli awsiotdev.Device, dialer map[string]Dialer, op t.opts = &Options{ TopicFunc: t.topic, EndpointHostFunc: endpointHost, + ProxyOptions: make(map[string][]ProxyOption), } for _, o := range opts { if err := o(t.opts); err != nil { @@ -101,11 +105,15 @@ func (t *tunnel) notify(msg *mqtt.Message) { for _, srv := range n.Services { if d, ok := t.dialerMap[srv]; ok { go func() { + opts := append( + []ProxyOption{WithErrorHandler(ErrorHandlerFunc(t.handleError))}, + t.opts.ProxyOptions[srv]..., + ) err := ProxyDestination( d, t.opts.EndpointHostFunc(n.Region), n.ClientAccessToken, - WithErrorHandler(ErrorHandlerFunc(t.handleError)), + opts..., ) if err != nil { t.handleError(ioterr.New(err, "creating proxy destination"))