diff --git a/pkg/lang/ir/graph.go b/pkg/lang/ir/graph.go index fc7930d4a..ac5bea30a 100644 --- a/pkg/lang/ir/graph.go +++ b/pkg/lang/ir/graph.go @@ -42,6 +42,7 @@ func NewGraph() *Graph { // They are used by vscode remote. "curl", "openssh-client", + "git", }, PyPIPackages: []string{}, diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index c9878d2ac..ef1da500f 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -32,6 +32,7 @@ import ( "github.com/sirupsen/logrus" "github.com/tensorchord/envd/pkg/lang/ir" "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/agent" "golang.org/x/term" ) @@ -40,8 +41,7 @@ type Client interface { } type generalClient struct { - config *ssh.ClientConfig - server string + cli *ssh.Client } func NewClient(server, user string, @@ -54,6 +54,8 @@ func NewClient(server, user string, }, } + var cli *ssh.Client + if auth { // read private key file pemBytes, err := ioutil.ReadFile(privateKeyPath) @@ -70,27 +72,51 @@ func NewClient(server, user string, } } + host := fmt.Sprintf("%s:%d", server, port) + // open connection + conn, err := ssh.Dial("tcp", host, config) + if err != nil { + return nil, errors.Wrap(err, "dialing failed") + } + cli = conn + + // open connection to the local agent + socketLocation := os.Getenv("SSH_AUTH_SOCK") + if socketLocation != "" { + agentConn, err := net.Dial("unix", socketLocation) + if err != nil { + return nil, errors.Wrap(err, "could not connect to local agent socket") + } + // create agent and add in auth + forwardingAgent := agent.NewClient(agentConn) + // add callback for forwarding agent to SSH config + // XXX - might want to handle reconnects appending multiple callbacks + auth := ssh.PublicKeysCallback(forwardingAgent.Signers) + config.Auth = append(config.Auth, auth) + if err := agent.ForwardToAgent(cli, forwardingAgent); err != nil { + return nil, errors.Wrap(err, "forwarding agent to client failed") + } + } + return &generalClient{ - config: config, - server: fmt.Sprintf("%v:%v", server, port), + cli: cli, }, nil } func (c generalClient) Attach() error { - // open connection - conn, err := ssh.Dial("tcp", c.server, c.config) - if err != nil { - return fmt.Errorf("dial to %v failed %v", c.server, err) - } - defer conn.Close() + defer c.cli.Close() // open session - session, err := conn.NewSession() + session, err := c.cli.NewSession() if err != nil { - return fmt.Errorf("create session for %v failed %v", c.server, err) + return errors.Wrap(err, "creating session failed") } defer session.Close() + if err := agent.RequestAgentForwarding(session); err != nil { + return errors.Wrap(err, "requesting agent forwarding failed") + } + modes := ssh.TerminalModes{ ssh.ECHO: 0, // Disable echoing ssh.ECHOCTL: 0, // Don't print control chars