Skip to content

Commit

Permalink
feat: Support cli argument for host key (#992)
Browse files Browse the repository at this point in the history
* feat: Support cli argument for host key

Signed-off-by: Ce Gao <cegao@tensorchord.ai>

* fix: Update

Signed-off-by: Ce Gao <cegao@tensorchord.ai>

* fix: Remove comments

Signed-off-by: Ce Gao <cegao@tensorchord.ai>

Signed-off-by: Ce Gao <cegao@tensorchord.ai>
  • Loading branch information
gaocegege committed Oct 9, 2022
1 parent f249e4b commit b2a9018
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 35 deletions.
34 changes: 30 additions & 4 deletions cmd/envd-sshd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package main

import (
"fmt"
"io/ioutil"
"os"

"github.com/cockroachdb/errors"
rawssh "github.com/gliderlabs/ssh"
"github.com/gliderlabs/ssh"
"github.com/sirupsen/logrus"
"github.com/urfave/cli/v2"
rawssh "golang.org/x/crypto/ssh"

"github.com/tensorchord/envd/pkg/config"
"github.com/tensorchord/envd/pkg/remote/sshd"
Expand All @@ -36,6 +38,7 @@ const (
flagNoAuth = "no-auth"
flagPort = "port"
flagShell = "shell"
flagHostKey = "hostkey"
)

func main() {
Expand All @@ -59,14 +62,20 @@ func main() {
EnvVars: []string{"ENVD_AUTHORIZED_KEYS_PATH"},
Aliases: []string{"a"},
},
&cli.StringFlag{
Name: flagHostKey,
Usage: "path to the host key",
EnvVars: []string{"ENVD_HOST_KEY"},
},
&cli.BoolFlag{
Name: flagNoAuth,
Usage: "disable authentication",
Value: false,
},
&cli.IntFlag{
Name: flagPort,
Usage: "port to listen on",
Name: flagPort,
Usage: "port to listen on",
Aliases: []string{"p"},
},
&cli.StringFlag{
Name: flagShell,
Expand Down Expand Up @@ -107,7 +116,7 @@ func sshServer(c *cli.Context) error {
}

noAuth := c.Bool(flagNoAuth)
var keys []rawssh.PublicKey
var keys []ssh.PublicKey
if !noAuth {
var err error
path := c.String(flagAuthKey)
Expand All @@ -125,10 +134,27 @@ func sshServer(c *cli.Context) error {
logrus.Warn("no authentication enabled")
}

var hostKey ssh.Signer = nil
if c.String(flagHostKey) != "" {
// read private key file
pemBytes, err := ioutil.ReadFile(c.String(flagHostKey))
if err != nil {
return errors.Wrapf(
err, "reading private key %s failed", c.String(flagHostKey))
}
if privateKey, err := rawssh.ParsePrivateKey(pemBytes); err != nil {
return err
} else {
logrus.Debugf("load host key from %s", c.String(flagHostKey))
hostKey = privateKey
}
}

srv := sshd.Server{
Port: port,
Shell: shell,
AuthorizedKeys: keys,
Hostkey: hostKey,
}

logrus.Infof("ssh server %s started in 0.0.0.0:%d", version.GetVersion().String(), srv.Port)
Expand Down
1 change: 1 addition & 0 deletions pkg/app/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ func sshc(clicontext *cli.Context) error {
opt.User = it
opt.PrivateKeyPath = clicontext.Path("private-key")
opt.Port = 2222
opt.AgentForwarding = false
sshClient, err := ssh.NewClient(opt)
if err != nil {
return errors.Wrap(err, "failed to create the ssh client")
Expand Down
10 changes: 8 additions & 2 deletions pkg/remote/sshd/sshd.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ func LoadAuthorizedKeys(path string) ([]ssh.PublicKey, error) {

// Server holds the ssh server configuration.
type Server struct {
Port int
Shell string
Port int
Shell string

AuthorizedKeys []ssh.PublicKey
Hostkey ssh.Signer
}

// ListenAndServe starts the SSH server using port
Expand Down Expand Up @@ -117,6 +119,10 @@ func (srv *Server) getServer() (*ssh.Server, error) {
server.PasswordHandler = nil
}

if srv.Hostkey != nil {
server.AddHostKey(srv.Hostkey)
}

return server, nil
}

Expand Down
70 changes: 41 additions & 29 deletions pkg/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,22 @@ type Client interface {
}

type Options struct {
Server string
User string
Port int
Auth bool
PrivateKeyPath string
PrivateKeyPwd string
AgentForwarding bool
Server string
User string
Port int
Auth bool
PrivateKeyPath string
PrivateKeyPwd string
}

func DefaultOptions() Options {
return Options{
Server: "localhost",
User: "envd",
Auth: true,
PrivateKeyPwd: "",
Server: "localhost",
User: "envd",
Auth: true,
PrivateKeyPwd: "",
AgentForwarding: true,
}
}

Expand All @@ -82,6 +84,7 @@ func GetOptions(entry string) (*Options, error) {

type generalClient struct {
cli *ssh.Client
opt *Options
}

func NewClient(opt Options) (Client, error) {
Expand Down Expand Up @@ -120,26 +123,31 @@ func NewClient(opt Options) (Client, error) {
}
cli = conn

// open connection to the local agent
socketLocation := os.Getenv("SSH_AUTH_SOCK")
if socketLocation != "" {
agentConn, err := net.Dial("unix", socketLocation)
if err != nil {
return nil, errors.Wrap(err, "could not connect to local agent socket")
}
// create agent and add in auth
forwardingAgent := agent.NewClient(agentConn)
// add callback for forwarding agent to SSH config
// XXX - might want to handle reconnects appending multiple callbacks
auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
config.Auth = append(config.Auth, auth)
if err := agent.ForwardToAgent(cli, forwardingAgent); err != nil {
return nil, errors.Wrap(err, "forwarding agent to client failed")
if opt.AgentForwarding {
// open connection to the local agent
socketLocation := os.Getenv("SSH_AUTH_SOCK")
if socketLocation != "" {
agentConn, err := net.Dial("unix", socketLocation)
if err != nil {
return nil, errors.Wrap(err, "could not connect to local agent socket")
}
// create agent and add in auth
forwardingAgent := agent.NewClient(agentConn)
// add callback for forwarding agent to SSH config
// might want to handle reconnects appending multiple callbacks
auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
config.Auth = append(config.Auth, auth)
if err := agent.ForwardToAgent(cli, forwardingAgent); err != nil {
return nil, errors.Wrap(err, "forwarding agent to client failed")
}
} else {
logrus.Warn("failed to get the environment variable SSH_AUTH_SOCK")
}
}

return &generalClient{
cli: cli,
opt: &opt,
}, nil
}

Expand All @@ -157,8 +165,10 @@ func (c generalClient) ExecWithOutput(cmd string) ([]byte, error) {
}
defer session.Close()

if err := agent.RequestAgentForwarding(session); err != nil {
return nil, errors.Wrap(err, "requesting agent forwarding failed")
if c.opt.AgentForwarding {
if err := agent.RequestAgentForwarding(session); err != nil {
return nil, errors.Wrap(err, "requesting agent forwarding failed")
}
}

return session.CombinedOutput(cmd)
Expand All @@ -172,8 +182,10 @@ func (c generalClient) Attach() error {
}
defer session.Close()

if err := agent.RequestAgentForwarding(session); err != nil {
return errors.Wrap(err, "requesting agent forwarding failed")
if c.opt.AgentForwarding {
if err := agent.RequestAgentForwarding(session); err != nil {
return errors.Wrap(err, "requesting agent forwarding failed")
}
}

modes := ssh.TerminalModes{
Expand Down

0 comments on commit b2a9018

Please sign in to comment.