Skip to content

Commit

Permalink
make TestGetHostPortRange unit test deterministic
Browse files Browse the repository at this point in the history
  • Loading branch information
singholt committed Feb 22, 2023
1 parent cff9c85 commit 193a8e5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 27 deletions.
68 changes: 41 additions & 27 deletions agent/utils/ephemeral_ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"time"

"github.com/docker/go-connections/nat"
"github.com/pkg/errors"
)

// From https://www.kernel.org/doc/html/latest//networking/ip-sysctl.html#ip-variables
Expand Down Expand Up @@ -132,34 +133,12 @@ func GetHostPortRange(numberOfPorts int, protocol string, dynamicHostPortRange s
func getHostPortRange(numberOfPorts, start, end int, protocol string) (string, int, error) {
var resultStartPort, resultEndPort, n int
for port := start; port <= end; port++ {
portStr := strconv.Itoa(port)
// check if port is available
if protocol == "tcp" {
// net.Listen announces on the local tcp network
ln, err := net.Listen(protocol, ":"+portStr)
// either port is unavailable or some error occurred while listening, we proceed to the next port
if err != nil {
continue
}
// let's close the listener first
err = ln.Close()
if err != nil {
continue
}
} else if protocol == "udp" {
// net.ListenPacket announces on the local udp network
ln, err := net.ListenPacket(protocol, ":"+portStr)
// either port is unavailable or some error occurred while listening, we proceed to the next port
if err != nil {
continue
}
// let's close the listener first
err = ln.Close()
if err != nil {
continue
}
isPortAvailable, err := isPortAvailableFunc(port, protocol)
if !isPortAvailable || err != nil {
// either port is unavailable or some error occurred while listening or closing the listener,
// we proceed to the next port
continue
}

// check if current port is contiguous relative to lastPort
if port-resultEndPort != 1 {
resultStartPort = port
Expand All @@ -182,3 +161,38 @@ func getHostPortRange(numberOfPorts, start, end int, protocol string) (string, i

return fmt.Sprintf("%d-%d", resultStartPort, resultEndPort), resultEndPort, nil
}

var isPortAvailableFunc = isPortAvailable

// isPortAvailable checks if a port is available
func isPortAvailable(port int, protocol string) (bool, error) {
portStr := strconv.Itoa(port)
switch protocol {
case "tcp":
// net.Listen announces on the local tcp network
ln, err := net.Listen(protocol, ":"+portStr)
if err != nil {
return false, err
}
// let's close the listener first
err = ln.Close()
if err != nil {
return false, err
}
return true, nil
case "udp":
// net.ListenPacket announces on the local udp network
ln, err := net.ListenPacket(protocol, ":"+portStr)
if err != nil {
return true, err
}
// let's close the listener first
err = ln.Close()
if err != nil {
return false, err
}
return true, ln.Close()
default:
return false, errors.New("invalid protocol")
}
}
10 changes: 10 additions & 0 deletions agent/utils/ephemeral_ports_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,16 @@ func TestGetHostPortRange(t *testing.T) {
},
}

// mock isPortAvailable() for unit test
// this ensures that the test doesn't rely on the runtime port availability on the host
isPortAvailableFuncTmp := isPortAvailableFunc
defer func() {
isPortAvailableFunc = isPortAvailableFuncTmp
}()
isPortAvailableFunc = func(port int, protocol string) error {
return nil
}

for _, tc := range testCases {
t.Run(tc.testName, func(t *testing.T) {
for i := 0; i < tc.numberOfRequests; i++ {
Expand Down

0 comments on commit 193a8e5

Please sign in to comment.