Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix SSH agent forwarding to agentless nodes #22567

Merged
merged 2 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/exp/slices"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
Expand Down Expand Up @@ -7333,8 +7334,15 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) {
}, tc.Username)
require.NoError(t, err)

_, _, err = nodeClient.Client.Client.SendRequest("test-request", true, nil)
// forward SSH agent
sshClient := nodeClient.Client.Client
session, err := sshClient.NewSession()
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, session.Close())
})
require.NoError(t, agent.ForwardToAgent(sshClient, tc.LocalAgent()))
require.NoError(t, agent.RequestAgentForwarding(session))

require.NoError(t, nodeClient.Close())
}
Expand Down Expand Up @@ -7366,21 +7374,41 @@ func startSSHServer(t *testing.T, caPubKeys []ssh.PublicKey, hostKey ssh.Signer)

go func() {
nConn, err := lis.Accept()
require.NoError(t, err)
assert.NoError(t, err)
t.Cleanup(func() {
// the error is ignored here to avoid failing on net.ErrClosed
_ = nConn.Close()
})

conn, _, reqs, err := ssh.NewServerConn(nConn, &sshCfg)
require.NoError(t, err)
conn, channels, reqs, err := ssh.NewServerConn(nConn, &sshCfg)
assert.NoError(t, err)
t.Cleanup(func() {
// the error is ignored here to avoid failing on net.ErrClosed
_ = conn.Close()
})
go ssh.DiscardRequests(reqs)

req := <-reqs
require.NoError(t, req.Reply(true, nil))
var agentForwarded bool
for channelReq := range channels {
assert.Equal(t, "session", channelReq.ChannelType())
channel, reqs, err := channelReq.Accept()
assert.NoError(t, err)
t.Cleanup(func() {
// the error is ignored here to avoid failing on net.ErrClosed
_ = channel.Close()
capnspacehook marked this conversation as resolved.
Show resolved Hide resolved
})

require.NoError(t, conn.Close())
for req := range reqs {
if req.WantReply {
assert.NoError(t, req.Reply(true, nil))
}
if req.Type == sshutils.AgentForwardRequest {
agentForwarded = true
capnspacehook marked this conversation as resolved.
Show resolved Hide resolved
break
}
}
}
assert.True(t, agentForwarded)
}()

return lis.Addr().String()
Expand Down
1 change: 0 additions & 1 deletion lib/proxy/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,6 @@ func (r *Router) DialHost(ctx context.Context, clientSrcAddr, clientDstAddr net.
if err != nil {
return nil, "", trace.Wrap(err)
}
// TODO(capnspacehook): remove when forwarding SSH agent to agentless node works
agentGetter = nil
}

Expand Down
55 changes: 29 additions & 26 deletions lib/srv/forward/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,10 @@ func (s *Server) GetLockWatcher() *services.LockWatcher {
}

func (s *Server) Serve() {
config := &ssh.ServerConfig{}
var (
succeeded bool
config = &ssh.ServerConfig{}
)

// Configure callback for user certificate authentication.
config.PublicKeyCallback = s.authHandlers.UserKeyAuth
Expand All @@ -538,15 +541,22 @@ func (s *Server) Serve() {
s.log.Debugf("Supported KEX algorithms: %q.", s.kexAlgorithms)
s.log.Debugf("Supported MAC algorithms: %q.", s.macAlgorithms)

sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config)
if err != nil {
// close
defer func() {
if succeeded {
return
}

if s.userAgent != nil {
s.userAgent.Close()
}
s.targetConn.Close()
s.clientConn.Close()
s.serverConn.Close()
}()

sconn, chans, reqs, err := ssh.NewServerConn(s.serverConn, config)
if err != nil {
s.log.Errorf("Unable to create server connection: %v.", err)
return
}
Expand All @@ -558,13 +568,6 @@ func (s *Server) Serve() {
// Take connection and extract identity information for the user from it.
s.identityContext, err = s.authHandlers.CreateIdentityContext(sconn)
if err != nil {
if s.userAgent != nil {
s.userAgent.Close()
}
s.targetConn.Close()
s.clientConn.Close()
s.serverConn.Close()

s.log.Errorf("Unable to create server connection: %v.", err)
return
}
Expand All @@ -578,17 +581,12 @@ func (s *Server) Serve() {
s.rejectChannel(chans, err.Error())
sconn.Close()

if s.userAgent != nil {
s.userAgent.Close()
}
s.targetConn.Close()
s.clientConn.Close()
s.serverConn.Close()

s.log.Errorf("Unable to create remote connection: %v", err)
return
}

succeeded = true

// The keep-alive loop will keep pinging the remote server and after it has
// missed a certain number of keep-alive requests it will cancel the
// closeContext which signals the server to shutdown.
Expand Down Expand Up @@ -646,12 +644,12 @@ func (s *Server) newRemoteClient(ctx context.Context, systemLogin string) (*trac
var signers []ssh.Signer
if s.agentlessSigner != nil {
signers = []ssh.Signer{s.agentlessSigner}
} else if s.userAgent != nil {
s, err := s.userAgent.Signers()
} else {
var err error
signers, err = s.userAgent.Signers()
if err != nil {
return nil, trace.Wrap(err)
}
signers = s
}
authMethod := ssh.PublicKeysCallback(signersWithSHA1Fallback(signers))

Expand Down Expand Up @@ -1140,19 +1138,24 @@ func (s *Server) dispatch(ctx context.Context, ch ssh.Channel, req *ssh.Request,
}

func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.ServerContext) error {
// TODO(capnspacehook): remove once SSH agent forwarding issue is fixed
if s.userAgent == nil {
return trace.BadParameter("SSH agent is not set")
}

// Check if the user's RBAC role allows agent forwarding.
err := s.authHandlers.CheckAgentForward(ctx)
if err != nil {
return trace.Wrap(err)
}

// Route authentication requests to the agent that was forwarded to the proxy.
err = agent.ForwardToAgent(ctx.RemoteClient.Client, s.userAgent)
// If no agent was forwarded to the proxy, create one now.
userAgent := s.userAgent
if userAgent == nil {
ctx.ConnectionContext.SetForwardAgent(true)
userAgent, err = ctx.StartAgentChannel()
if err != nil {
return trace.Wrap(err)
}
}

err = agent.ForwardToAgent(ctx.RemoteClient.Client, userAgent)
if err != nil {
return trace.Wrap(err)
}
Expand Down
4 changes: 1 addition & 3 deletions lib/sshutils/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ func (a *agentChannel) Close() error {
func (c *ConnectionContext) StartAgentChannel() (teleagent.Agent, error) {
// refuse to start an agent if forwardAgent has not yet been set.
if !c.GetForwardAgent() {
// TODO(capnspacehook): update SSH agent in forwarding SSH server
// when connecting to agentless nodes
return nil, trace.AccessDenied("agent forwarding required in proxy recording mode")
return nil, trace.AccessDenied("agent forwarding has not been requested")
}
// open a agent channel to client
ch, _, err := c.ServerConn.OpenChannel(AuthAgentRequest, nil)
Expand Down