From b2a90188c012fa186de2ecb7b1aa534062681020 Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Sun, 9 Oct 2022 20:31:48 +0800 Subject: [PATCH] feat: Support cli argument for host key (#992) * feat: Support cli argument for host key Signed-off-by: Ce Gao * fix: Update Signed-off-by: Ce Gao * fix: Remove comments Signed-off-by: Ce Gao Signed-off-by: Ce Gao --- cmd/envd-sshd/main.go | 34 +++++++++++++++++--- pkg/app/ssh.go | 1 + pkg/remote/sshd/sshd.go | 10 ++++-- pkg/ssh/ssh.go | 70 ++++++++++++++++++++++++----------------- 4 files changed, 80 insertions(+), 35 deletions(-) diff --git a/cmd/envd-sshd/main.go b/cmd/envd-sshd/main.go index 168332625..6e12d2f26 100644 --- a/cmd/envd-sshd/main.go +++ b/cmd/envd-sshd/main.go @@ -18,12 +18,14 @@ package main import ( "fmt" + "io/ioutil" "os" "github.com/cockroachdb/errors" - rawssh "github.com/gliderlabs/ssh" + "github.com/gliderlabs/ssh" "github.com/sirupsen/logrus" "github.com/urfave/cli/v2" + rawssh "golang.org/x/crypto/ssh" "github.com/tensorchord/envd/pkg/config" "github.com/tensorchord/envd/pkg/remote/sshd" @@ -36,6 +38,7 @@ const ( flagNoAuth = "no-auth" flagPort = "port" flagShell = "shell" + flagHostKey = "hostkey" ) func main() { @@ -59,14 +62,20 @@ func main() { EnvVars: []string{"ENVD_AUTHORIZED_KEYS_PATH"}, Aliases: []string{"a"}, }, + &cli.StringFlag{ + Name: flagHostKey, + Usage: "path to the host key", + EnvVars: []string{"ENVD_HOST_KEY"}, + }, &cli.BoolFlag{ Name: flagNoAuth, Usage: "disable authentication", Value: false, }, &cli.IntFlag{ - Name: flagPort, - Usage: "port to listen on", + Name: flagPort, + Usage: "port to listen on", + Aliases: []string{"p"}, }, &cli.StringFlag{ Name: flagShell, @@ -107,7 +116,7 @@ func sshServer(c *cli.Context) error { } noAuth := c.Bool(flagNoAuth) - var keys []rawssh.PublicKey + var keys []ssh.PublicKey if !noAuth { var err error path := c.String(flagAuthKey) @@ -125,10 +134,27 @@ func sshServer(c *cli.Context) error { logrus.Warn("no authentication enabled") } + var hostKey ssh.Signer = nil + if c.String(flagHostKey) != "" { + // read private key file + pemBytes, err := ioutil.ReadFile(c.String(flagHostKey)) + if err != nil { + return errors.Wrapf( + err, "reading private key %s failed", c.String(flagHostKey)) + } + if privateKey, err := rawssh.ParsePrivateKey(pemBytes); err != nil { + return err + } else { + logrus.Debugf("load host key from %s", c.String(flagHostKey)) + hostKey = privateKey + } + } + srv := sshd.Server{ Port: port, Shell: shell, AuthorizedKeys: keys, + Hostkey: hostKey, } logrus.Infof("ssh server %s started in 0.0.0.0:%d", version.GetVersion().String(), srv.Port) diff --git a/pkg/app/ssh.go b/pkg/app/ssh.go index d7bac4f41..9f3562ffe 100644 --- a/pkg/app/ssh.go +++ b/pkg/app/ssh.go @@ -58,6 +58,7 @@ func sshc(clicontext *cli.Context) error { opt.User = it opt.PrivateKeyPath = clicontext.Path("private-key") opt.Port = 2222 + opt.AgentForwarding = false sshClient, err := ssh.NewClient(opt) if err != nil { return errors.Wrap(err, "failed to create the ssh client") diff --git a/pkg/remote/sshd/sshd.go b/pkg/remote/sshd/sshd.go index b6018c808..269f52b15 100644 --- a/pkg/remote/sshd/sshd.go +++ b/pkg/remote/sshd/sshd.go @@ -68,9 +68,11 @@ func LoadAuthorizedKeys(path string) ([]ssh.PublicKey, error) { // Server holds the ssh server configuration. type Server struct { - Port int - Shell string + Port int + Shell string + AuthorizedKeys []ssh.PublicKey + Hostkey ssh.Signer } // ListenAndServe starts the SSH server using port @@ -117,6 +119,10 @@ func (srv *Server) getServer() (*ssh.Server, error) { server.PasswordHandler = nil } + if srv.Hostkey != nil { + server.AddHostKey(srv.Hostkey) + } + return server, nil } diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index 8188cd8fc..8bd3ac77e 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -47,20 +47,22 @@ type Client interface { } type Options struct { - Server string - User string - Port int - Auth bool - PrivateKeyPath string - PrivateKeyPwd string + AgentForwarding bool + Server string + User string + Port int + Auth bool + PrivateKeyPath string + PrivateKeyPwd string } func DefaultOptions() Options { return Options{ - Server: "localhost", - User: "envd", - Auth: true, - PrivateKeyPwd: "", + Server: "localhost", + User: "envd", + Auth: true, + PrivateKeyPwd: "", + AgentForwarding: true, } } @@ -82,6 +84,7 @@ func GetOptions(entry string) (*Options, error) { type generalClient struct { cli *ssh.Client + opt *Options } func NewClient(opt Options) (Client, error) { @@ -120,26 +123,31 @@ func NewClient(opt Options) (Client, error) { } cli = conn - // open connection to the local agent - socketLocation := os.Getenv("SSH_AUTH_SOCK") - if socketLocation != "" { - agentConn, err := net.Dial("unix", socketLocation) - if err != nil { - return nil, errors.Wrap(err, "could not connect to local agent socket") - } - // create agent and add in auth - forwardingAgent := agent.NewClient(agentConn) - // add callback for forwarding agent to SSH config - // XXX - might want to handle reconnects appending multiple callbacks - auth := ssh.PublicKeysCallback(forwardingAgent.Signers) - config.Auth = append(config.Auth, auth) - if err := agent.ForwardToAgent(cli, forwardingAgent); err != nil { - return nil, errors.Wrap(err, "forwarding agent to client failed") + if opt.AgentForwarding { + // open connection to the local agent + socketLocation := os.Getenv("SSH_AUTH_SOCK") + if socketLocation != "" { + agentConn, err := net.Dial("unix", socketLocation) + if err != nil { + return nil, errors.Wrap(err, "could not connect to local agent socket") + } + // create agent and add in auth + forwardingAgent := agent.NewClient(agentConn) + // add callback for forwarding agent to SSH config + // might want to handle reconnects appending multiple callbacks + auth := ssh.PublicKeysCallback(forwardingAgent.Signers) + config.Auth = append(config.Auth, auth) + if err := agent.ForwardToAgent(cli, forwardingAgent); err != nil { + return nil, errors.Wrap(err, "forwarding agent to client failed") + } + } else { + logrus.Warn("failed to get the environment variable SSH_AUTH_SOCK") } } return &generalClient{ cli: cli, + opt: &opt, }, nil } @@ -157,8 +165,10 @@ func (c generalClient) ExecWithOutput(cmd string) ([]byte, error) { } defer session.Close() - if err := agent.RequestAgentForwarding(session); err != nil { - return nil, errors.Wrap(err, "requesting agent forwarding failed") + if c.opt.AgentForwarding { + if err := agent.RequestAgentForwarding(session); err != nil { + return nil, errors.Wrap(err, "requesting agent forwarding failed") + } } return session.CombinedOutput(cmd) @@ -172,8 +182,10 @@ func (c generalClient) Attach() error { } defer session.Close() - if err := agent.RequestAgentForwarding(session); err != nil { - return errors.Wrap(err, "requesting agent forwarding failed") + if c.opt.AgentForwarding { + if err := agent.RequestAgentForwarding(session); err != nil { + return errors.Wrap(err, "requesting agent forwarding failed") + } } modes := ssh.TerminalModes{