Skip to content

Commit

Permalink
Increment port if already in use (fixes #250)
Browse files Browse the repository at this point in the history
* increment port

* limit port

* gofmt

* remove debuggy print statements

* broken test

* 3001

* move logic and testing into the socket packages

* fix comment

Co-authored-by: EC2 Default User <ec2-user@ip-172-31-40-12.us-east-2.compute.internal>
  • Loading branch information
matthewmueller and EC2 Default User authored Sep 3, 2022
1 parent ebf81a4 commit b03e09a
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 1 deletion.
3 changes: 2 additions & 1 deletion internal/cli/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ func (c *Command) Run(ctx context.Context) (err error) {
// Listening on the web listener as soon as possible
webln := c.in.WebLn
if webln == nil {
webln, err = socket.Listen(c.Listen)
// Listen and increment if the port is already in use up to 10 times
webln, err = socket.ListenUp(c.Listen, 10)
if err != nil {
return err
}
Expand Down
41 changes: 41 additions & 0 deletions package/socket/socket.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@ package socket

import (
"context"
"errors"
"fmt"
"net"
"net/http"
"os"
"strconv"
"syscall"
"time"

"github.com/livebud/bud/internal/urlx"
)

// ErrAddrInUse occurs when a port is already in use
var ErrAddrInUse = syscall.EADDRINUSE

type Listener interface {
net.Listener
file
Expand Down Expand Up @@ -70,6 +76,41 @@ func Listen(path string) (Listener, error) {
return &listener{tcp}, nil
}

// ListenUp is similar to listen, but will increment the port number until it
// finds a free one or reaches the maximum number of attempts
func ListenUp(path string, attempts int) (Listener, error) {
ln, err := Listen(path)
if err != nil {
if !errors.Is(err, ErrAddrInUse) {
return nil, err
}
if attempts--; attempts >= 0 {
newPath, err := incrementPort(path)
if err != nil {
return nil, err
}
return ListenUp(newPath, attempts-1)
}
return nil, err
}
return ln, nil
}

// Takes a address and increments the port by 1
func incrementPort(path string) (string, error) {
url, err := urlx.Parse(path)
if err != nil {
return "", err
}
port, err := strconv.Atoi(url.Port())
if err != nil {
return "", err
}
port++
url.Host = url.Hostname() + ":" + strconv.Itoa(port)
return url.String(), nil
}

// Dial creates a connection to an address
func Dial(ctx context.Context, address string) (net.Conn, error) {
url, err := urlx.Parse(address)
Expand Down
47 changes: 47 additions & 0 deletions package/socket/socket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@ import (
"context"
"errors"
"io"
"net"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"

"github.com/livebud/bud/internal/urlx"

"github.com/livebud/bud/internal/is"
"github.com/livebud/bud/package/socket"
)
Expand Down Expand Up @@ -123,3 +127,46 @@ func TestUDSCleanup(t *testing.T) {
is.True(errors.Is(err, os.ErrNotExist))
is.Equal(stat, nil)
}

func TestListenUp(t *testing.T) {
is := is.New(t)
ln0, err := socket.Listen(":0")
is.NoErr(err)
defer ln0.Close()
ln1, err := socket.ListenUp(ln0.Addr().String(), 5)
is.NoErr(err)
defer ln1.Close()
priorURL, err := urlx.Parse(ln0.Addr().String())
is.NoErr(err)
priorPort, err := strconv.Atoi(priorURL.Port())
is.NoErr(err)
url, err := urlx.Parse(ln1.Addr().String())
is.NoErr(err)
port, err := strconv.Atoi(url.Port())
is.NoErr(err)
is.Equal(port, priorPort+1)
}

func TestListenMaxAttemptsReached(t *testing.T) {
is := is.New(t)
ln0, err := socket.Listen(":0")
is.NoErr(err)
defer ln0.Close()
// This one should work
ln1, err := socket.ListenUp(ln0.Addr().String(), 1)
is.NoErr(err)
defer ln1.Close()
// This one should fail because we're using ln0 as the base
ln2, err := socket.ListenUp(ln0.Addr().String(), 1)
is.True(errors.Is(err, socket.ErrAddrInUse))
is.Equal(ln2, nil)
}

func TestListenPortTooHigh(t *testing.T) {
is := is.New(t)
ln0, err := socket.Listen(":65536")
ae, ok := err.(*net.AddrError)
is.True(ok)
is.Equal(ae.Err, "invalid port")
is.Equal(ln0, nil)
}

0 comments on commit b03e09a

Please sign in to comment.