Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deflake TestAgentForward #13166

Merged
merged 4 commits into from
Jun 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ type TestAuthServerConfig struct {
ClusterNetworkingConfig types.ClusterNetworkingConfig
// Streamer allows a test to set its own audit events streamer.
Streamer events.Streamer
// AuditLog allows a test to configure its own audit log.
AuditLog events.IAuditLog
}

// CheckAndSetDefaults checks and sets defaults
Expand Down Expand Up @@ -201,16 +203,20 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
// Wrap backend in sanitizer like in production.
srv.Backend = backend.NewSanitizer(b)

localLog, err := events.NewAuditLog(events.AuditLogConfig{
DataDir: cfg.Dir,
ServerID: cfg.ClusterName,
Clock: cfg.Clock,
UploadHandler: events.NewMemoryUploader(),
})
if err != nil {
return nil, trace.Wrap(err)
if cfg.AuditLog != nil {
srv.AuditLog = cfg.AuditLog
} else {
localLog, err := events.NewAuditLog(events.AuditLogConfig{
DataDir: cfg.Dir,
ServerID: cfg.ClusterName,
Clock: cfg.Clock,
UploadHandler: events.NewMemoryUploader(),
})
if err != nil {
return nil, trace.Wrap(err)
}
srv.AuditLog = localLog
}
srv.AuditLog = localLog

srv.SessionServer, err = session.New(srv.Backend)
if err != nil {
Expand All @@ -221,7 +227,7 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) {
identity := local.NewIdentityService(srv.Backend)

emitter, err := events.NewCheckingEmitter(events.CheckingEmitterConfig{
Inner: localLog,
Inner: srv.AuditLog,
Clock: cfg.Clock,
})
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion lib/events/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func (cfg *ProtoStreamConfig) CheckAndSetDefaults() error {
return nil
}

// NewProtoStream uploads session recordings to the protobuf format.
// NewProtoStream uploads session recordings in the protobuf format.
//
// The individual session stream is represented by continuous globally
// ordered sequence of events serialized to binary protobuf format.
Expand Down
44 changes: 28 additions & 16 deletions lib/srv/regular/sshserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/native"
"github.com/gravitational/teleport/lib/bpf"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/limiter"
"github.com/gravitational/teleport/lib/pam"
restricted "github.com/gravitational/teleport/lib/restrictedsession"
Expand Down Expand Up @@ -157,6 +158,7 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO
},
})
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, nodeClient.Close()) })

nodeDir := t.TempDir()
serverOptions := []ServerOption{
Expand Down Expand Up @@ -197,7 +199,10 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO
require.NoError(t, err)
require.NoError(t, auth.CreateUploaderDir(nodeDir))
require.NoError(t, sshSrv.Start())
t.Cleanup(func() { require.NoError(t, sshSrv.Close()) })
t.Cleanup(func() {
require.NoError(t, sshSrv.Close())
sshSrv.Wait()
})

require.NoError(t, sshSrv.heartbeat.ForceSend(time.Second))

Expand Down Expand Up @@ -246,7 +251,6 @@ func newCustomFixture(t *testing.T, mutateCfg func(*auth.TestServerConfig), sshO

t.Cleanup(func() { f.ssh.assertCltClose(t, client.Close()) })
require.NoError(t, agent.ForwardToAgent(client, keyring))

return f
}

Expand Down Expand Up @@ -593,7 +597,9 @@ func TestOpenExecSessionSetsSession(t *testing.T) {
// TestAgentForward tests agent forwarding via unix sockets
func TestAgentForward(t *testing.T) {
t.Parallel()
f := newFixture(t)
f := newCustomFixture(t, func(cfg *auth.TestServerConfig) {
cfg.Auth.AuditLog = events.NewDiscardAuditLog()
})

ctx := context.Background()
roleName := services.RoleNameForUser(f.user)
Expand All @@ -605,29 +611,36 @@ func TestAgentForward(t *testing.T) {
err = f.testSrv.Auth().UpsertRole(ctx, role)
require.NoError(t, err)

// use a sync recording mode because the disk-based uploader
// that runs in the background introduces races with test cleanup
recConfig := types.DefaultSessionRecordingConfig()
recConfig.SetMode(types.RecordAtNodeSync)
err = f.testSrv.Auth().SetSessionRecordingConfig(ctx, recConfig)
require.NoError(t, err)

se, err := f.ssh.clt.NewSession()
require.NoError(t, err)
defer se.Close()
t.Cleanup(func() { se.Close() })

err = agent.RequestAgentForwarding(se)
require.NoError(t, err)

// prepare to send virtual "keyboard input" into the shell:
keyboard, err := se.StdinPipe()
require.NoError(t, err)
t.Cleanup(func() { keyboard.Close() })

// start interactive SSH session (new shell):
err = se.Shell()
require.NoError(t, err)

// create a temp file to collect the shell output into:
tmpFile, err := os.CreateTemp(os.TempDir(), "teleport-agent-forward-test")
tmpFile, err := os.CreateTemp(t.TempDir(), "teleport-agent-forward-test")
require.NoError(t, err)
tmpFile.Close()
defer os.Remove(tmpFile.Name())

// type 'printenv SSH_AUTH_SOCK > /path/to/tmp/file' into the session (dumping the value of SSH_AUTH_STOCK into the temp file)
_, err = keyboard.Write([]byte(fmt.Sprintf("printenv %v >> %s\n\r", teleport.SSHAuthSock, tmpFile.Name())))
_, err = fmt.Fprintf(keyboard, "printenv %v >> %s\n\r", teleport.SSHAuthSock, tmpFile.Name())
require.NoError(t, err)

// wait for the output
Expand All @@ -644,8 +657,8 @@ func TestAgentForward(t *testing.T) {
// try dialing the ssh agent socket:
file, err := net.Dial("unix", socketPath)
require.NoError(t, err)
clientAgent := agent.NewClient(file)

clientAgent := agent.NewClient(file)
signers, err := clientAgent.Signers()
require.NoError(t, err)

Expand All @@ -665,6 +678,7 @@ func TestAgentForward(t *testing.T) {
// sessions on the connection).
err = se.Close()
require.NoError(t, err)

// Pause to allow closure to propagate.
time.Sleep(150 * time.Millisecond)
_, err = net.Dial("unix", socketPath)
Expand All @@ -678,14 +692,12 @@ func TestAgentForward(t *testing.T) {

// clt must be nullified to prevent double-close during test cleanup
f.ssh.clt = nil
for i := 0; i < 4; i++ {
_, err = net.Dial("unix", socketPath)
if err != nil {
return
}
time.Sleep(50 * time.Millisecond)
}
require.FailNow(t, "expected socket to be closed, still could dial after 150 ms")
require.Eventually(t, func() bool {
_, err := net.Dial("unix", socketPath)
return err != nil
},
150*time.Millisecond, 50*time.Millisecond,
"expected socket to be closed, still could dial")
}

// TestX11Forward tests x11 forwarding via unix sockets
Expand Down