diff --git a/tests/config.go b/tests/config.go index 77e8474..871fac9 100644 --- a/tests/config.go +++ b/tests/config.go @@ -4,14 +4,13 @@ const ( bytesToSend = 1024 bufLen = 100 numClients = 1 - listenerId = "Bob" dialerId = "Alice" seedHex = "e68e046d13dd911594576ba0f4a196e9666790dc492071ad9ea5972c0b940435" - remoteAddr = "Bob.be285ff9330122cea44487a9618f96603fde6d37d5909ae1c271616772c349fe" - toAddrTcp = "127.0.0.1:54321" - fromAddrTcp = "127.0.0.1:12345" - - toAddrUdp = "127.0.0.1:54321" - fromAddrUdp = "127.0.0.1:12346" + listenerId = "Bob1" + toPort = "127.0.0.1:54321" ) + +var fromPorts = []string{"127.0.0.1:12345"} +var fromUDPPorts = []string{"127.0.0.1:22345"} +var remoteAddrs = []string{"Bob1.be285ff9330122cea44487a9618f96603fde6d37d5909ae1c271616772c349fe"} diff --git a/tests/pub.go b/tests/pub.go index cf3ff17..1408025 100644 --- a/tests/pub.go +++ b/tests/pub.go @@ -4,6 +4,8 @@ import ( "encoding/hex" "fmt" "log" + "strings" + "time" nkn "github.com/nknorg/nkn-sdk-go" ts "github.com/nknorg/nkn-tuna-session" @@ -24,20 +26,20 @@ const ( tunnelClientIsReady = "tunnel client is ready" tcpServerIsReady = "tcp server is ready" udpServerIsReady = "udp server is ready" - tcpDialerExit = "tcp dialer exit" + exit = "exit" tcpServerExit = "tcp server exit" udpServerExit = "udp server exit" udpClientExit = "udp client exit" ) -var ch chan string = make(chan string, 4) +var ch chan string func waitFor(ch chan string, status string) { - fmt.Println("waiting for ", status) + fmt.Println("Waiting for:", status) for { str := <-ch - fmt.Println("waitFor got: ", str) - if status == str { + fmt.Println("Got:", str) + if strings.Contains(str, status) { break } } @@ -104,12 +106,9 @@ func CreateTunaSession(account *nkn.Account, wallet *nkn.Wallet, mc *nkn.MultiCl return } -var node *types.Node +var tunaNode *types.Node func CreateTunnelConfig(udp bool) *tunnel.Config { - if node == nil { - node = StartTunaNode() - } config := &tunnel.Config{ NumSubClients: numClients, ClientConfig: CreateClientConfig(3), @@ -118,7 +117,7 @@ func CreateTunnelConfig(udp bool) *tunnel.Config { TunaSessionConfig: CreateTunaSessionConfig(numClients), Verbose: true, UDP: udp, - TunaNode: node, + TunaNode: tunaNode, } return config @@ -151,6 +150,7 @@ func StartTunaNode() *types.Node { func runReverseEntry(seed []byte) error { entryAccount, err := vault.NewAccountWithSeed(seed) if err != nil { + fmt.Println("runReverseEntry vault.NewAccountWithSeed err ", err) return err } seedRPCServerAddr := nkn.NewStringArray(nkn.DefaultSeedRPCServerAddr...) @@ -160,19 +160,88 @@ func runReverseEntry(seed []byte) error { } entryWallet, err := nkn.NewWallet(&nkn.Account{Account: entryAccount}, walletConfig) if err != nil { + fmt.Println("runReverseEntry nkn.NewWallet err ", err) return err } entryConfig := new(tuna.EntryConfiguration) err = util.ReadJSON("config.reverse.entry.json", entryConfig) if err != nil { + fmt.Println("runReverseEntry util.ReadJSON err ", err) return err } err = tuna.StartReverse(entryConfig, entryWallet) if err != nil { + fmt.Println("runReverseEntry tuna.StartReverse err ", err) return err } ch <- tunaNodeStarted + return nil +} + +func StartTunnelListeners(tuna bool) error { + acc, _, err := CreateAccountAndWallet(seedHex) + if err != nil { + return err + } + + config := CreateTunnelConfig(tuna) + + tunnels, err := tunnel.NewTunnels(acc, listenerId, []string{"nkn"}, []string{toPort}, tuna, config) + if err != nil { + return err + } + time.Sleep(10 * time.Second) // wait for tuna node is ready + if tuna { + for _, t := range tunnels { + ts := t.TunaSessionClient() + <-ts.OnConnect() + ch <- tunaSessionConnected + } + } + ch <- tunnelServerIsReady + fmt.Printf("tunnel server is ready, toPort is %v\n", toPort) + + for _, t := range tunnels { + err = t.Start() + if err != nil { + return err + } + } + + return nil +} + +func StartTunnelDialers(tcp, tuna bool) error { + acc, _, err := CreateAccountAndWallet(seedHex) + if err != nil { + return err + } + + config := CreateTunnelConfig(tuna) + var from []string + if tcp { + from = fromPorts + } else { + from = fromUDPPorts + } + + tunnels, err := tunnel.NewTunnels(acc, dialerId, from, remoteAddrs, tuna, config) + if err != nil { + return err + } + + for _, t := range tunnels { + go func(t *tunnel.Tunnel) { + err := t.Start() + if err != nil { + fmt.Printf("tunnel.Start err: %v\n", err) + return + } + }(t) + } - select {} + time.Sleep(5 * time.Second) // Tunnel start time + ch <- tunnelClientIsReady + return nil } diff --git a/tests/tcp_test.go b/tests/tcp_test.go index 7830313..157a185 100644 --- a/tests/tcp_test.go +++ b/tests/tcp_test.go @@ -3,15 +3,21 @@ package tests import ( "fmt" "net" + "os" "strings" + "sync" "testing" "time" - - tunnel "github.com/nknorg/nkn-tunnel" ) -// go test -v -run=TestTCPWriteReadData -func TestTCPWriteReadData(t *testing.T) { +// go test -v -run=TestTCP +func TestTCP(t *testing.T) { + ch = make(chan string, 4) + if tunaNode == nil { + tunaNode = StartTunaNode() + waitFor(ch, tunaNodeStarted) + } + go func() { err := StartTcpServer() if err != nil { @@ -22,18 +28,19 @@ func TestTCPWriteReadData(t *testing.T) { waitFor(ch, tcpServerIsReady) tuna := true + go func() { - err := StartTunnelListener(toAddrTcp, tuna) + err := StartTunnelListeners(tuna) if err != nil { - fmt.Printf("StartTunnelListener err: %v\n", err) - return + fmt.Printf("StartTunnelListeners err: %v\n", err) + os.Exit(-1) } }() waitFor(ch, tunnelServerIsReady) go func() { - err := StartTunnelDialer(fromAddrTcp, tuna) + err := StartTunnelDialers(true, tuna) if err != nil { fmt.Printf("StartTunnelDialer err: %v\n", err) return @@ -42,64 +49,18 @@ func TestTCPWriteReadData(t *testing.T) { waitFor(ch, tunnelClientIsReady) - go StartTcpDialer() + go StartTcpDialers() waitFor(ch, tcpServerExit) -} - -func StartTunnelListener(toAddr string, tuna bool) error { - acc, _, err := CreateAccountAndWallet(seedHex) - if err != nil { - return err - } - - config := CreateTunnelConfig(true) - tun, err := tunnel.NewTunnel(acc, listenerId, "nkn", toAddr, tuna, config) - if err != nil { - return err - } - time.Sleep(10 * time.Second) // wait for tuna node is ready - - ts := tun.TunaSessionClient() - <-ts.OnConnect() - ch <- tunaSessionConnected - ch <- tunnelServerIsReady - fmt.Printf("tunnel server is ready, toAddr is %v\n", toAddr) - - err = tun.Start() - if err != nil { - return err - } - return nil -} - -func StartTunnelDialer(fromAddr string, tuna bool) error { - acc, _, err := CreateAccountAndWallet(seedHex) - if err != nil { - return err - } - - config := CreateTunnelConfig(true) - tun, err := tunnel.NewTunnel(acc, listenerId, fromAddr, remoteAddr, tuna, config) - if err != nil { - return err - } - - ch <- tunnelClientIsReady - - err = tun.Start() - if err != nil { - return err - } - return nil + close(ch) } func StartTcpServer() error { - listener, err := net.Listen("tcp", toAddrTcp) + listener, err := net.Listen("tcp", toPort) if err != nil { return err } - fmt.Printf("tcp server is listening at %v\n", toAddrTcp) + fmt.Printf("StartTcpServer is listening at %v\n", toPort) ch <- tcpServerIsReady conn, err := listener.Accept() @@ -111,33 +72,60 @@ func StartTcpServer() error { for { n, err := conn.Read(b) if err != nil { + fmt.Printf("StartTcpServer conn.Read err %v\n", err) return err } - fmt.Printf("tcp server read: %v\n", string(b[:n])) - if strings.Contains(string(b[:n]), tcpDialerExit) { + fmt.Printf("TCP Server got: %v\n", string(b[:n])) + if strings.Contains(string(b[:n]), exit) { break } + // echo + _, err = conn.Write(b[:n]) + if err != nil { + fmt.Printf("StartTcpServer conn.Write err %v\n", err) + return err + } } ch <- tcpServerExit return nil } -func StartTcpDialer() error { - conn, err := net.Dial("tcp", fromAddrTcp) - if err != nil { - return err +func StartTcpDialers() error { + var wg sync.WaitGroup + for i, fromPort := range fromPorts { + wg.Add(1) + go func(clientNum int, from string) { + defer wg.Done() + conn, err := net.Dial("tcp", from) + if err != nil { + fmt.Printf("StartTcpDialers net.Dial to %v err %v\n", from, err) + return + } + + for i := 0; i < 10; i++ { + msg := fmt.Sprintf("tcp client %v data %v", clientNum, i) + _, err := conn.Write([]byte(msg)) + if err != nil { + fmt.Printf("StartTcpDialers conn.Write to %v err %v\n", from, err) + return + } + b := make([]byte, 1024) + n, err := conn.Read(b) + if err != nil { + fmt.Printf("StartTcpDialers conn.Read to %v err %v\n", from, err) + return + } + if string(b[:n]) != msg { + fmt.Printf("StartTcpDialers get echo %v, it should be %v\n", string(b[:n]), msg) + return + } + fmt.Printf("TCP Client %v got echo: %v\n", clientNum, string(b[:n])) + } + conn.Write([]byte(exit)) + time.Sleep(2 * time.Second) // wait for tcp server get it + }(i, fromPort) } - - for i := 0; i < 10; i++ { - _, err := conn.Write([]byte(fmt.Sprintf("tcp client data %v\n", i))) - if err != nil { - return err - } - } - conn.Write([]byte(tcpDialerExit)) - time.Sleep(2 * time.Second) // wait for tcp server get it - - ch <- tcpDialerExit + wg.Wait() return nil } diff --git a/tests/udp_test.go b/tests/udp_test.go index 7ea2e06..7355865 100644 --- a/tests/udp_test.go +++ b/tests/udp_test.go @@ -4,12 +4,19 @@ import ( "fmt" "net" "strings" + "sync" "testing" "time" ) // go test -v -run=TestUDP func TestUDP(t *testing.T) { + ch = make(chan string, 4) + if tunaNode == nil { + tunaNode = StartTunaNode() + waitFor(ch, tunaNodeStarted) + } + go func() { err := StartUdpServer() if err != nil { @@ -21,7 +28,7 @@ func TestUDP(t *testing.T) { waitFor(ch, udpServerIsReady) tuna := true go func() { - err := StartTunnelListener(toAddrUdp, tuna) + err := StartTunnelListeners(tuna) if err != nil { fmt.Printf("StartTunnelListener err: %v\n", err) return @@ -31,7 +38,7 @@ func TestUDP(t *testing.T) { waitFor(ch, tunnelServerIsReady) go func() { - err := StartTunnelDialer(fromAddrUdp, tuna) + err := StartTunnelDialers(false, tuna) if err != nil { fmt.Printf("StartTunnelDialer err: %v\n", err) return @@ -40,40 +47,46 @@ func TestUDP(t *testing.T) { waitFor(ch, tunnelClientIsReady) - for i := 0; i < 2; i++ { - go func(clientNum int) { - err := StartUdpClient(clientNum) - if err != nil { - fmt.Printf("StartTunnelDialer %v err: %v\n", clientNum, err) - return - } - }(i) + err := StartUdpClients() + if err != nil { + fmt.Printf("StartUdpClients err: %v\n", err) + return } waitFor(ch, udpServerExit) + close(ch) } func StartUdpServer() error { - a, err := net.ResolveUDPAddr("udp", toAddrUdp) + a, err := net.ResolveUDPAddr("udp", toPort) if err != nil { + fmt.Println("StartUdpServer ResolveUDPAddr err: ", err) return err } udpServer, err := net.ListenUDP("udp", a) if err != nil { + fmt.Println("StartUdpServer ListenUDP err: ", err) return err } + fmt.Printf("udp server is listening at %v\n", toPort) ch <- udpServerIsReady b := make([]byte, 1024) for { n, addr, err := udpServer.ReadFromUDP(b) if err != nil { + fmt.Println("StartUdpServer ReadFromUDP err: ", err) return err } fmt.Printf("UDP Server got: %v\n", string(b[:n])) time.Sleep(500 * time.Millisecond) - udpServer.WriteTo(b[:n], addr) - if strings.Contains(string(b[:n]), udpClientExit) { + n, err = udpServer.WriteTo(b[:n], addr) + if err != nil { + fmt.Printf("udpServer WriteTo err %v\n", err) + break + } + if strings.Contains(string(b[:n]), exit) { + fmt.Println("Udp Server got exit, exit now.") break } } @@ -82,39 +95,55 @@ func StartUdpServer() error { return nil } -func StartUdpClient(clientNo int) error { - a, err := net.ResolveUDPAddr("udp", fromAddrUdp) - if err != nil { - return err - } - udpClient, err := net.DialUDP("udp", nil, a) - if err != nil { - return err - } +func StartUdpClients() error { + var wg sync.WaitGroup + for i, fromPort := range fromUDPPorts { + wg.Add(1) + go func(clientNum int, from string) { + defer wg.Done() + fmt.Printf("upd client %v send to port %v\n", clientNum, from) - for j := 0; j < 10; j++ { - sendData := fmt.Sprintf("udp client %v am at %v", clientNo, j) - n, _, err := udpClient.WriteMsgUDP([]byte(sendData), nil, nil) - if err != nil { - fmt.Printf("StartUdpClient WriteMsgUDP err: %v\n", err) - return err - } + a, err := net.ResolveUDPAddr("udp", from) + if err != nil { + fmt.Println("StartUdpClient net.ResolveUDPAddr err: ", err) + return + } + udpClient, err := net.DialUDP("udp", nil, a) + if err != nil { + fmt.Println("StartUdpClient net.DialUDP err: ", err) + return + } - recvData := make([]byte, 1024) - n, _, err = udpClient.ReadFrom(recvData) - if err != nil { - fmt.Printf("StartUdpClient.ReadFrom err %v\n", err) - return err - } + for j := 0; j < 10; j++ { + sendData := fmt.Sprintf("udp client %v msg %v", clientNum, j) + _, _, err := udpClient.WriteMsgUDP([]byte(sendData), nil, nil) + if err != nil { + fmt.Printf("StartUdpClient WriteMsgUDP err: %v\n", err) + return + } + + recvData := make([]byte, 1024) + n, _, err := udpClient.ReadFrom(recvData) + if err != nil { + fmt.Printf("StartUdpClient.ReadFrom err %v\n", err) + return + } + + if string(recvData[:n]) != sendData { + fmt.Printf("udpClient.ReadFrom is not equal to I sent.\n") + fmt.Printf("I sent %v, received: %v\n", sendData, string(recvData[:n])) + } else { + fmt.Printf("UDP Client %v got echo: %v\n", clientNum, string(recvData[:n])) + } + time.Sleep(100 * time.Millisecond) + } + udpClient.WriteMsgUDP([]byte(exit), nil, nil) + time.Sleep(time.Second) + + }(i, fromPort) - if string(recvData[:n]) != sendData { - fmt.Printf("udpClient.ReadFrom is not equal to I sent.\n") - fmt.Printf("I sent %v, received: %v\n", sendData, string(recvData[:n])) - } - time.Sleep(1 * time.Second) } - udpClient.WriteMsgUDP([]byte(udpClientExit), nil, nil) - time.Sleep(time.Second) - ch <- udpClientExit + + wg.Wait() return nil } diff --git a/tunnel.go b/tunnel.go index f393dd2..0e0c90d 100644 --- a/tunnel.go +++ b/tunnel.go @@ -1,6 +1,7 @@ package tunnel import ( + "errors" "io" "log" "net" @@ -48,6 +49,20 @@ type Tunnel struct { // NewTunnel creates a Tunnel client with given options. func NewTunnel(account *nkn.Account, identifier, from, to string, tuna bool, config *Config) (*Tunnel, error) { + tunnels, err := NewTunnels(account, identifier, []string{from}, []string{to}, tuna, config) + if err != nil { + return nil, err + } + + return tunnels[0], nil +} + +// NewTunnels creates Tunnel clients with given options. +func NewTunnels(account *nkn.Account, identifier string, from, to []string, tuna bool, config *Config) ([]*Tunnel, error) { + if len(from) != len(to) || len(from) == 0 { + return nil, errors.New("from should have same length as to") + } + config, err := MergedConfig(config) if err != nil { return nil, err @@ -56,85 +71,98 @@ func NewTunnel(account *nkn.Account, identifier, from, to string, tuna bool, con return nil, ErrUDPNotSupported } - fromNKN := len(from) == 0 || strings.ToLower(from) == "nkn" - toNKN := !strings.Contains(to, ":") + udpConnExpired := cache.NoExpiration + if config.UDPIdleTime > 0 { + udpConnExpired = time.Duration(config.UDPIdleTime) * time.Second + } + + fromNKN := false + for _, f := range from { + fromNKN = (len(f) == 0 || strings.ToLower(f) == "nkn") + if fromNKN { + break + } + } + if fromNKN && len(from) > 1 { + return nil, errors.New("multiple tunnels is not supported when from NKN") + } + var m *nkn.MultiClient var c *ts.TunaSessionClient var dialer nknDialer - if fromNKN || toNKN { - m, err = nkn.NewMultiClient(account, identifier, config.NumSubClients, config.OriginalClient, config.ClientConfig) + m, err = nkn.NewMultiClient(account, identifier, config.NumSubClients, config.OriginalClient, config.ClientConfig) + if err != nil { + return nil, err + } + + <-m.OnConnect.C + dialer = newMultiClientDialer(m) + + if tuna { + wallet, err := nkn.NewWallet(account, config.WalletConfig) if err != nil { return nil, err } - <-m.OnConnect.C + c, err = ts.NewTunaSessionClient(account, m, wallet, config.TunaSessionConfig) + if err != nil { + return nil, err + } - dialer = newMultiClientDialer(m) + dialer = c + } - if tuna { - wallet, err := nkn.NewWallet(account, config.WalletConfig) - if err != nil { - return nil, err + tunnels := make([]*Tunnel, 0) + for i, f := range from { + toNKN := !strings.Contains(to[i], ":") + listeners := make([]net.Listener, 0, 2) + + if fromNKN { + if tuna { + if config.TunaNode != nil { + c.SetTunaNode(config.TunaNode) + } + listeners = append(listeners, c) + err = c.Listen(config.AcceptAddrs) + if err != nil { + return nil, err + } } - c, err = ts.NewTunaSessionClient(account, m, wallet, config.TunaSessionConfig) + listeners = append(listeners, m) + err = m.Listen(config.AcceptAddrs) if err != nil { return nil, err } - dialer = c - } - } - - listeners := make([]net.Listener, 0, 2) - - if fromNKN { - if tuna { - if config.TunaNode != nil { - c.SetTunaNode(config.TunaNode) - } - listeners = append(listeners, c) - err = c.Listen(config.AcceptAddrs) + f = m.Addr().String() + } else { + listener, err := net.Listen("tcp", f) if err != nil { return nil, err } + listeners = append(listeners, listener) } - listeners = append(listeners, m) - err = m.Listen(config.AcceptAddrs) - if err != nil { - return nil, err - } - from = m.Addr().String() - } else { - listener, err := net.Listen("tcp", from) - if err != nil { - return nil, err - } - listeners = append(listeners, listener) - } - - log.Println("Listening at", from) - udpConnExpired := cache.NoExpiration - if config.UDPIdleTime > 0 { - udpConnExpired = time.Duration(config.UDPIdleTime) * time.Second - } - - t := &Tunnel{ - from: from, - to: to, - fromNKN: fromNKN, - toNKN: toNKN, - config: config, - dialer: dialer, - listeners: listeners, - multiClient: m, - tsClient: c, - udpConnCache: cache.New(udpConnExpired, udpConnExpired), + log.Println("Listening at", f) + + t := &Tunnel{ + from: f, + to: to[i], + fromNKN: fromNKN, + toNKN: toNKN, + config: config, + dialer: dialer, + listeners: listeners, + multiClient: m, + tsClient: c, + udpConnCache: cache.New(udpConnExpired, udpConnExpired), + } + tunnels = append(tunnels, t) } - return t, nil + return tunnels, nil } // FromAddr returns the tunnel listening address. diff --git a/udp.go b/udp.go index 4b8968e..e040ceb 100644 --- a/udp.go +++ b/udp.go @@ -93,6 +93,7 @@ func (t *Tunnel) listenUDP() (udpConn, error) { if err != nil { return nil, err } + log.Println("Tunnel is listening at UDP", a.String()) } return fromUDPConn, nil