diff --git a/integration/integration_test.go b/integration/integration_test.go index c3aa88075f6f1..e24826eddecec 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -49,6 +49,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" "golang.org/x/exp/slices" "google.golang.org/grpc" "google.golang.org/grpc/connectivity" @@ -7333,8 +7334,15 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) { }, tc.Username) require.NoError(t, err) - _, _, err = nodeClient.Client.Client.SendRequest("test-request", true, nil) + // forward SSH agent + sshClient := nodeClient.Client.Client + session, err := sshClient.NewSession() require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, session.Close()) + }) + require.NoError(t, agent.ForwardToAgent(sshClient, tc.LocalAgent())) + require.NoError(t, agent.RequestAgentForwarding(session)) require.NoError(t, nodeClient.Close()) } @@ -7366,21 +7374,41 @@ func startSSHServer(t *testing.T, caPubKeys []ssh.PublicKey, hostKey ssh.Signer) go func() { nConn, err := lis.Accept() - require.NoError(t, err) + assert.NoError(t, err) t.Cleanup(func() { + // the error is ignored here to avoid failing on net.ErrClosed _ = nConn.Close() }) - conn, _, reqs, err := ssh.NewServerConn(nConn, &sshCfg) - require.NoError(t, err) + conn, channels, reqs, err := ssh.NewServerConn(nConn, &sshCfg) + assert.NoError(t, err) t.Cleanup(func() { + // the error is ignored here to avoid failing on net.ErrClosed _ = conn.Close() }) + go ssh.DiscardRequests(reqs) - req := <-reqs - require.NoError(t, req.Reply(true, nil)) + var agentForwarded bool + for channelReq := range channels { + assert.Equal(t, "session", channelReq.ChannelType()) + channel, reqs, err := channelReq.Accept() + assert.NoError(t, err) + t.Cleanup(func() { + // the error is ignored here to avoid failing on net.ErrClosed + _ = channel.Close() + }) - require.NoError(t, conn.Close()) + for req := range reqs { + if req.WantReply { + assert.NoError(t, req.Reply(true, nil)) + } + if req.Type == sshutils.AgentForwardRequest { + agentForwarded = true + break + } + } + } + assert.True(t, agentForwarded) }() return lis.Addr().String() diff --git a/lib/proxy/router.go b/lib/proxy/router.go index 73f3c83b6f07b..3427bd95d483e 100644 --- a/lib/proxy/router.go +++ b/lib/proxy/router.go @@ -285,7 +285,6 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net. if err != nil { return nil, "", trace.Wrap(err) } - // TODO(capnspacehook): remove when forwarding SSH agent to agentless node works agentGetter = nil } diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 40e5d4d649732..8e255323078e2 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -514,7 +514,10 @@ func (s *Server) GetLockWatcher() *services.LockWatcher { } func (s *Server) Serve() { - config := &ssh.ServerConfig{} + var ( + succeeded bool + config = &ssh.ServerConfig{} + ) // Configure callback for user certificate authentication. config.PublicKeyCallback = s.authHandlers.UserKeyAuth @@ -538,15 +541,22 @@ func (s *Server) Serve() { s.log.Debugf("Supported KEX algorithms: %q.", s.kexAlgorithms) s.log.Debugf("Supported MAC algorithms: %q.", s.macAlgorithms) - sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config) - if err != nil { + // close + defer func() { + if succeeded { + return + } + if s.userAgent != nil { s.userAgent.Close() } s.targetConn.Close() s.clientConn.Close() s.serverConn.Close() + }() + sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config) + if err != nil { s.log.Errorf("Unable to create server connection: %v.", err) return } @@ -558,13 +568,6 @@ func (s *Server) Serve() { // Take connection and extract identity information for the user from it. s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn) if err != nil { - if s.userAgent != nil { - s.userAgent.Close() - } - s.targetConn.Close() - s.clientConn.Close() - s.serverConn.Close() - s.log.Errorf("Unable to create server connection: %v.", err) return } @@ -578,17 +581,12 @@ func (s *Server) Serve() { s.rejectChannel(chans, err.Error()) sconn.Close() - if s.userAgent != nil { - s.userAgent.Close() - } - s.targetConn.Close() - s.clientConn.Close() - s.serverConn.Close() - s.log.Errorf("Unable to create remote connection: %v", err) return } + succeeded = true + // The keep-alive loop will keep pinging the remote server and after it has // missed a certain number of keep-alive requests it will cancel the // closeContext which signals the server to shutdown. @@ -646,12 +644,12 @@ func (s *Server) newRemoteClient(ctx context.Context, systemLogin string) (*trac var signers []ssh.Signer if s.agentlessSigner != nil { signers = []ssh.Signer{s.agentlessSigner} - } else if s.userAgent != nil { - s, err := s.userAgent.Signers() + } else { + var err error + signers, err = s.userAgent.Signers() if err != nil { return nil, trace.Wrap(err) } - signers = s } authMethod := ssh.PublicKeysCallback(signersWithSHA1Fallback(signers)) @@ -1140,11 +1138,6 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request, } func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error { - // TODO(capnspacehook): remove once SSH agent forwarding issue is fixed - if s.userAgent == nil { - return trace.BadParameter("SSH agent is not set") - } - // Check if the user's RBAC role allows agent forwarding. err := s.authHandlers.CheckAgentForward(ctx) if err != nil { @@ -1152,7 +1145,17 @@ func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.S } // Route authentication requests to the agent that was forwarded to the proxy. - err = agent.ForwardToAgent(ctx.RemoteClient.Client, s.userAgent) + // If no agent was forwarded to the proxy, create one now. + userAgent := s.userAgent + if userAgent == nil { + ctx.ConnectionContext.SetForwardAgent(true) + userAgent, err = ctx.StartAgentChannel() + if err != nil { + return trace.Wrap(err) + } + } + + err = agent.ForwardToAgent(ctx.RemoteClient.Client, userAgent) if err != nil { return trace.Wrap(err) } diff --git a/lib/sshutils/ctx.go b/lib/sshutils/ctx.go index 922fb88ca5219..4a2d41414d1a4 100644 --- a/lib/sshutils/ctx.go +++ b/lib/sshutils/ctx.go @@ -132,9 +132,7 @@ func (a *agentChannel) Close() error { func (c *ConnectionContext) StartAgentChannel() (teleagent.Agent, error) { // refuse to start an agent if forwardAgent has not yet been set. if !c.GetForwardAgent() { - // TODO(capnspacehook): update SSH agent in forwarding SSH server - // when connecting to agentless nodes - return nil, trace.AccessDenied("agent forwarding required in proxy recording mode") + return nil, trace.AccessDenied("agent forwarding has not been requested") } // open a agent channel to client ch, _, err := c.ServerConn.OpenChannel(AuthAgentRequest, nil)