Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to test gateway opening/closing in Connect #14135

Merged
merged 11 commits into from
Jul 7, 2022
1 change: 1 addition & 0 deletions lib/srv/alpnproxy/local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ func createAWSAccessProxySuite(t *testing.T, cred *credentials.Credentials) *Loc
t.Cleanup(func() {
err := lp.Close()
require.NoError(t, err)
hs.Close()
})
go func() {
err := lp.StartAWSAccessProxy(context.Background())
Expand Down
24 changes: 18 additions & 6 deletions lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,11 @@ func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams)
return nil, trace.Wrap(err)
}

gateway.Open()
go func() {
if err := gateway.Serve(); err != nil {
gateway.Log.WithError(err).Warn("Failed to open a connection.")
}
}()

s.gateways = append(s.gateways, gateway)

Expand Down Expand Up @@ -202,22 +206,28 @@ func (s *Service) RemoveGateway(ctx context.Context, gatewayURI string) error {
s.mu.Lock()
defer s.mu.Unlock()

s.removeGateway(gateway)
if err := s.removeGateway(gateway); err != nil {
return trace.Wrap(err)
}

return nil
}

// removeGateway assumes that mu is already held by a public method.
func (s *Service) removeGateway(gateway *gateway.Gateway) {
gateway.Close()
func (s *Service) removeGateway(gateway *gateway.Gateway) error {
if err := gateway.Close(); err != nil {
espadolini marked this conversation as resolved.
Show resolved Hide resolved
return trace.Wrap(err)
}

// remove closed gateway from list
for index := range s.gateways {
if s.gateways[index] == gateway {
s.gateways = append(s.gateways[:index], s.gateways[index+1:]...)
espadolini marked this conversation as resolved.
Show resolved Hide resolved
return
return nil
}
}

return trace.NotFound("gateway %v not found in gateway list", gateway.URI.String())
}

// RestartGateway stops a gateway and starts a new one with identical parameters.
Expand All @@ -232,7 +242,9 @@ func (s *Service) RestartGateway(ctx context.Context, gatewayURI string) error {
s.mu.Lock()
defer s.mu.Unlock()

s.removeGateway(gateway)
if err := s.removeGateway(gateway); err != nil {
return trace.Wrap(err)
}

newGateway, err := s.createGateway(ctx, CreateGatewayParams{
TargetURI: gateway.TargetURI,
Expand Down
28 changes: 17 additions & 11 deletions lib/teleterm/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,21 +104,27 @@ func New(cfg Config, cliCommandProvider CLICommandProvider) (*Gateway, error) {
}

// Close terminates gateway connection
func (g *Gateway) Close() {
func (g *Gateway) Close() error {
g.closeCancel()
g.localProxy.Close()

if err := g.localProxy.Close(); err != nil {
return trace.Wrap(err)
}

return nil
}

// Open opens a gateway to Teleport proxy
func (g *Gateway) Open() {
go func() {
g.Log.Info("Gateway is open.")
if err := g.localProxy.Start(g.closeContext); err != nil {
g.Log.WithError(err).Warn("Failed to open a connection.")
}
// Serve starts the underlying ALPN proxy. Blocks until closeContext is canceled.
func (g *Gateway) Serve() error {
g.Log.Info("Gateway is open.")

g.Log.Info("Gateway has closed.")
}()
if err := g.localProxy.Start(g.closeContext); err != nil {
return trace.Wrap(err)
}

g.Log.Info("Gateway has closed.")

return nil
}

// LocalPortInt returns the port of a gateway as an integer rather than a string.
Expand Down
44 changes: 44 additions & 0 deletions lib/teleterm/gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@ package gateway

import (
"fmt"
"net"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/teleterm/api/uri"

"github.com/stretchr/testify/require"
)
Expand All @@ -45,3 +50,42 @@ func TestCLICommandUsesCLICommandProvider(t *testing.T) {

require.Equal(t, "foo/bar", command)
}

func TestGatewayStart(t *testing.T) {
hs := httptest.NewTLSServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}))
t.Cleanup(func() {
hs.Close()
})

gateway, err := New(
Config{
TargetName: "foo",
TargetURI: uri.NewClusterURI("bar").AppendDB("foo").String(),
TargetUser: "alice",
Protocol: defaults.ProtocolPostgres,
CertPath: "../../../fixtures/certs/proxy1.pem",
KeyPath: "../../../fixtures/certs/proxy1-key.pem",
Insecure: true,
WebProxyAddr: hs.Listener.Addr().String(),
},
mockCLICommandProvider{},
)
require.NoError(t, err)

serveErr := make(chan error)

go func() {
err := gateway.Serve()
serveErr <- err
}()

// Dial to make sure gateway is open.
gatewayAddress := net.JoinHostPort(gateway.LocalAddress, gateway.LocalPort)
conn, err := net.DialTimeout("tcp", gatewayAddress, time.Second*1)
require.NoError(t, err)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@espadolini One thing that worries me is that this check fails only if I don't start the listener on that address. It doesn't fail if I never actually start the local proxy or if I start it after 1.5 second.

It turns out that localProxy.Start(), unlike gRPC's Server.Serve(), doesn't error if it's called after Close().

So the test not failing if Serve() is called after Close() is okay I guess. But why doesn't the test fail if never start the local proxy?

Copy link
Member Author

@ravicious ravicious Jul 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to adapt this to teleterm_test.go to replace time.Sleep with Dial but it seems that I misunderstood how Dial works. It doesn't actually wait for the other side to accept the connection since, as you said, the listener buffers those connections before they're accepted by the server.

So I cannot really use Dial there to make sure the server is accepting connections.

Diff of teleterm_test.go
diff --git a/lib/teleterm/teleterm.go b/lib/teleterm/teleterm.go
index b5356fdc9..123a5be95 100644
--- a/lib/teleterm/teleterm.go
+++ b/lib/teleterm/teleterm.go
@@ -16,8 +16,10 @@ package teleterm
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"os/signal"
+	"time"
 
 	"github.com/gravitational/teleport/lib/teleterm/apiserver"
 	"github.com/gravitational/teleport/lib/teleterm/clusters"
@@ -60,7 +62,10 @@ func Start(ctx context.Context, cfg Config) error {
 
 	serverAPIWait := make(chan error)
 	go func() {
+		time.Sleep(time.Second * 2)
+		fmt.Println("apiServer.Serve()")
 		err := apiServer.Serve()
+		fmt.Println("apiServer.Serve() returned")
 		serverAPIWait <- err
 	}()
 
diff --git a/lib/teleterm/teleterm_test.go b/lib/teleterm/teleterm_test.go
index 51be1b619..649bac66f 100644
--- a/lib/teleterm/teleterm_test.go
+++ b/lib/teleterm/teleterm_test.go
@@ -16,7 +16,11 @@ package teleterm_test
 
 import (
 	"context"
+	"errors"
 	"fmt"
+	"net"
+	"os"
+	"path/filepath"
 	"testing"
 	"time"
 
@@ -27,27 +31,38 @@ import (
 
 func TestStart(t *testing.T) {
 	homeDir := t.TempDir()
+	sockPath := filepath.Join(homeDir, "teleterm.sock")
+	addr := fmt.Sprintf("unix://%v", sockPath)
 	cfg := teleterm.Config{
-		Addr:    fmt.Sprintf("unix://%v/teleterm.sock", homeDir),
+		Addr:    addr,
 		HomeDir: fmt.Sprintf("%v/", homeDir),
 	}
 
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	wait := make(chan error)
+	serveErr := make(chan error)
 	go func() {
 		err := teleterm.Start(ctx, cfg)
-		wait <- err
+		serveErr <- err
 	}()
 
-	defer func() {
-		// Make sure Start() is called.
-		time.Sleep(time.Millisecond * 500)
+	require.Eventually(t, func() bool {
+		_, err := os.Stat(sockPath)
 
-		// Stop the server.
-		cancel()
-		require.NoError(t, <-wait)
-	}()
+		return !errors.Is(err, os.ErrNotExist)
+	}, time.Millisecond*500, time.Millisecond*50)
+
+	// Dial to make sure Start was called.
+	fmt.Println("net.DialTimeout()")
+	conn, err := net.DialTimeout("unix", sockPath, time.Second*1)
+	fmt.Println("net.DialTimeout() returned")
+	require.NoError(t, err)
 
+	// Stop the server.
+	fmt.Println("cancel()")
+	cancel()
+	conn.Close()
+	fmt.Println("waiting for serveErr")
+	require.NoError(t, <-serveErr)
 }

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would conn.Read() block until the server accepts the connection?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I think conn.Read() does the trick. I just pushed 3e34e02 which fails if the connection isn't accepted within a second.

I was also able to improve teleterm tests in a similar way, I'll push those changes in a separate PR.

t.Cleanup(func() { conn.Close() })

err = gateway.Close()
require.NoError(t, err)
require.NoError(t, <-serveErr)
}