Skip to content

Commit

Permalink
require instead of assert; fix missing close connectons in one test
Browse files Browse the repository at this point in the history
  • Loading branch information
cre4ture committed Sep 15, 2024
1 parent 84d1a26 commit ad2dbdc
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 40 deletions.
68 changes: 37 additions & 31 deletions port_forwarder/port_forwarder_tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/slackhq/nebula/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func startReadToChannel(receiverConn net.Conn) <-chan []byte {
Expand Down Expand Up @@ -44,7 +45,7 @@ func doTestTcpCommunication(
fmt.Println("sending ...")
t.Log("sending ...")
n, err = senderConn.Write(data_sent)
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, n, len(data_sent))

fmt.Println("receiving ...")
Expand All @@ -57,7 +58,7 @@ func doTestTcpCommunication(
}
fmt.Println("DONE")
t.Log("DONE")
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, n, len(data_sent))
assert.Equal(t, data_sent, buf[:n])
}
Expand All @@ -73,7 +74,7 @@ func doTestTcpCommunicationFail(
if err != nil {
return
}
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, n, len(data_sent))

buf := make([]byte, 100)
Expand All @@ -89,7 +90,7 @@ func tcpListenerNAccept(t *testing.T, listener *net.TCPListener, n int) <-chan n
r <- true
for range n {
conn, err := listener.Accept()
assert.Nil(t, err)
require.Nil(t, err)
c <- conn
}
}()
Expand All @@ -110,7 +111,7 @@ port_forwarding:
dial_address: 127.0.0.1:5595
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

Expand All @@ -121,27 +122,32 @@ port_forwarding:
dial_address: 10.0.0.1:4495
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, client_pf.portForwardings, 1)

client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3395")
assert.Nil(t, err)
require.Nil(t, err)
server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5595")
assert.Nil(t, err)
require.Nil(t, err)

server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer server_listen_conn.Close()
server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 2)

client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer client1_conn.Close()

client1_rcv_chan := startReadToChannel(client1_conn)
client1_server_side_conn := <-server_listen_conn_accepts
client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn)

client2_conn, err := net.DialTCP("tcp", nil, client_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer client2_conn.Close()

client2_rcv_chan := startReadToChannel(client2_conn)
client2_server_side_conn := <-server_listen_conn_accepts
client2_server_side_rcv_chan := startReadToChannel(client2_server_side_conn)
Expand Down Expand Up @@ -173,7 +179,7 @@ port_forwarding:
dial_address: 127.0.0.1:5597
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

Expand All @@ -186,27 +192,27 @@ port_forwarding:
dial_address: 10.0.0.1:4497
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, client_pf.portForwardings, 1)

time.Sleep(100 * time.Millisecond)

client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397")
assert.Nil(t, err)
require.Nil(t, err)
server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597")
assert.Nil(t, err)
require.Nil(t, err)

server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer server_listen_conn.Close()

server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1)

time.Sleep(100 * time.Millisecond)

client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer client1_conn.Close()
client1_rcv_chan := startReadToChannel(client1_conn)

Expand All @@ -232,7 +238,7 @@ port_forwarding:
dial_address: 127.0.0.1:5596
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

Expand All @@ -243,10 +249,10 @@ port_forwarding:
dial_address: 10.0.0.1:4496
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list)
assert.Nil(t, err)
require.Nil(t, err)

doTestTcpCommunicationFail(t, "Hello from client 1 side!",
client1_conn, client1_server_side_conn)
Expand All @@ -255,7 +261,7 @@ port_forwarding:
client1_server_side_conn, client1_conn)

err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list)
assert.Nil(t, err)
require.Nil(t, err)

doTestTcpCommunicationFail(t, "Hello from client 1 side!",
client1_conn, client1_server_side_conn)
Expand All @@ -274,7 +280,7 @@ port_forwarding:
dial_address: 127.0.0.1:5599
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

Expand All @@ -285,22 +291,22 @@ port_forwarding:
dial_address: 10.0.0.1:4499
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, client_pf.portForwardings, 1)

client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3399")
assert.Nil(t, err)
require.Nil(t, err)
server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5599")
assert.Nil(t, err)
require.Nil(t, err)

server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer server_listen_conn.Close()
server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1)

client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer client1_conn.Close()
client1_rcv_chan := startReadToChannel(client1_conn)

Expand All @@ -326,7 +332,7 @@ port_forwarding:
dial_address: 127.0.0.1:5598
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

Expand All @@ -337,10 +343,10 @@ port_forwarding:
dial_address: 10.0.0.1:4498
protocols: [tcp]
`)
assert.Nil(t, err)
require.Nil(t, err)

err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list)
assert.Nil(t, err)
require.Nil(t, err)

doTestTcpCommunicationFail(t, "Hello from client 1 side!",
client1_conn, client1_server_side_conn)
Expand All @@ -349,7 +355,7 @@ port_forwarding:
client1_server_side_conn, client1_conn)

err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list)
assert.Nil(t, err)
require.Nil(t, err)

doTestTcpCommunicationFail(t, "Hello from client 1 side!",
client1_conn, client1_server_side_conn)
Expand Down
19 changes: 10 additions & 9 deletions port_forwarder/port_forwarder_udp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/service"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func loadPortFwdConfigFromString(l *logrus.Logger, configStr string) (*PortForwardingList, error) {
Expand Down Expand Up @@ -65,11 +66,11 @@ func doTestUdpCommunication(
} else {
n, err = senderConn.Write(data_sent)
}
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, n, len(data_sent))

pair := <-receiverConn
assert.Nil(t, err)
require.Nil(t, err)
assert.Equal(t, data_sent, pair.a)
return pair.b
}
Expand Down Expand Up @@ -107,7 +108,7 @@ port_forwarding:
dial_address: 127.0.0.1:5599
protocols: [udp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, server_pf.portForwardings, 1)

Expand All @@ -118,27 +119,27 @@ port_forwarding:
dial_address: 10.0.0.1:4499
protocols: [udp]
`)
assert.Nil(t, err)
require.Nil(t, err)

assert.Len(t, client_pf.portForwardings, 1)

client_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3399")
assert.Nil(t, err)
require.Nil(t, err)
server_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5599")
assert.Nil(t, err)
require.Nil(t, err)

server_listen_conn, err := net.ListenUDP("udp", server_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer server_listen_conn.Close()
server_listen_rcv_chan := readUdpConnectionToChannel(server_listen_conn)

client1_conn, err := net.DialUDP("udp", nil, client_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer client1_conn.Close()
client1_rcv_chan := readUdpConnectionToChannel(client1_conn)

client2_conn, err := net.DialUDP("udp", nil, client_conn_addr)
assert.Nil(t, err)
require.Nil(t, err)
defer client2_conn.Close()
client2_rcv_chan := readUdpConnectionToChannel(client2_conn)

Expand Down

0 comments on commit ad2dbdc

Please sign in to comment.