From 0fc4d8d12e9a13c00e760ee105765b0270a2c3ee Mon Sep 17 00:00:00 2001 From: Ce Gao Date: Thu, 13 Oct 2022 15:21:52 +0800 Subject: [PATCH] feat(context): Get ssh hostname from context, instead of hard-coded string (#1020) * fix: Use envd-server in the context Signed-off-by: Ce Gao * feat: Add ssh address discovery Signed-off-by: Ce Gao * fix: Remove unused struct Signed-off-by: Ce Gao Signed-off-by: Ce Gao --- docs/proposals/20220603-kubernetes-vendor.md | 4 +- pkg/app/create.go | 9 +++- pkg/app/up.go | 19 +++++--- pkg/envd/factory.go | 20 +++++++-- pkg/ssh/ssh.go | 1 - pkg/types/envd.go | 14 ++++++ pkg/util/netutil/netutil.go | 20 ++++++++- pkg/util/netutil/netutil_test.go | 46 ++++++++++++++++++++ 8 files changed, 118 insertions(+), 15 deletions(-) diff --git a/docs/proposals/20220603-kubernetes-vendor.md b/docs/proposals/20220603-kubernetes-vendor.md index 6ba14e77e..d0e6a7979 100644 --- a/docs/proposals/20220603-kubernetes-vendor.md +++ b/docs/proposals/20220603-kubernetes-vendor.md @@ -63,7 +63,7 @@ Users may use `envd` to build the image, and use it on Kubernetes. Thus they nee The end-to-end process will be: ``` -$ envd context create --name test --builder-name test --use --builder kube-pod --runner server --runner-addr http://localhost:2222 +$ envd context create --name test --builder-name test --use --builder kube-pod --runner server --runner-addr http://localhost:8080 $ envd login $ envd build $ envd push @@ -71,7 +71,7 @@ $ envd run --env test --image test or -$ envd context create --name test --builder-name test --use --builder kube-pod --runner server --runner-addr http://localhost:2222 +$ envd context create --name test --builder-name test --use --builder kube-pod --runner server --runner-addr http://localhost:8080 $ envd login $ envd up ``` diff --git a/pkg/app/create.go b/pkg/app/create.go index 131e37bfb..b69704fc1 100644 --- a/pkg/app/create.go +++ b/pkg/app/create.go @@ -91,9 +91,14 @@ func create(clicontext *cli.Context) error { logrus.Debugf("container %s is running", res.Name) logrus.Debugf("add entry %s to SSH config.", res.Name) + hostname, err := c.GetSSHHostname() + if err != nil { + return errors.Wrap(err, "failed to get the ssh hostname") + } + eo := sshconfig.EntryOptions{ Name: res.Name, - IFace: localhost, + IFace: hostname, Port: res.SSHPort, PrivateKeyPath: clicontext.Path("private-key"), EnableHostKeyCheck: false, @@ -111,6 +116,8 @@ func create(clicontext *cli.Context) error { opt.Port = res.SSHPort opt.AgentForwarding = false opt.User = res.Name + opt.Server = hostname + sshClient, err := ssh.NewClient(opt) if err != nil { return errors.Wrap(err, "failed to create the ssh client") diff --git a/pkg/app/up.go b/pkg/app/up.go index dd9f3dfd2..8b222eb79 100644 --- a/pkg/app/up.go +++ b/pkg/app/up.go @@ -30,10 +30,6 @@ import ( "github.com/tensorchord/envd/pkg/types" ) -const ( - localhost = "127.0.0.1" -) - var CommandUp = &cli.Command{ Name: "up", Category: CategoryBasic, @@ -120,9 +116,13 @@ var CommandUp = &cli.Command{ } func up(clicontext *cli.Context) error { + c, err := home.GetManager().ContextGetCurrent() + if err != nil { + return errors.Wrap(err, "failed to get the current context") + } buildOpt, err := ParseBuildOpt(clicontext) if err != nil { - return err + return errors.Wrap(err, "failed to parse the build options") } ctr := filepath.Base(buildOpt.BuildContextDir) @@ -197,9 +197,14 @@ func up(clicontext *cli.Context) error { logrus.Debugf("container %s is running", res.Name) logrus.Debugf("add entry %s to SSH config.", ctr) + hostname, err := c.GetSSHHostname() + if err != nil { + return errors.Wrap(err, "failed to get the ssh hostname") + } + eo := sshconfig.EntryOptions{ Name: ctr, - IFace: localhost, + IFace: hostname, Port: res.SSHPort, PrivateKeyPath: clicontext.Path("private-key"), EnableHostKeyCheck: false, @@ -218,6 +223,8 @@ func up(clicontext *cli.Context) error { if err != nil { return errors.Wrap(err, "failed to create the ssh client") } + opt.Server = hostname + if err := sshClient.Attach(); err != nil { return errors.Wrap(err, "failed to attach to the container") } diff --git a/pkg/envd/factory.go b/pkg/envd/factory.go index 56c932093..7540fed44 100644 --- a/pkg/envd/factory.go +++ b/pkg/envd/factory.go @@ -36,12 +36,24 @@ func New(ctx context.Context, opt Options) (Engine, error) { if opt.Context.Runner == types.RunnerTypeEnvdServer { ac, err := home.GetManager().AuthGetCurrent() if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to get the auth information") } - cli, err := envdclient.NewClientWithOpts(envdclient.FromEnv) + // Get the runner host. + opts := []envdclient.Opt{ + envdclient.WithTLSClientConfigFromEnv(), + } + if opt.Context.RunnerAddress != nil { + opts = append(opts, + envdclient.WithHost(*opt.Context.RunnerAddress)) + } else { + opts = append(opts, + envdclient.WithHostFromEnv()) + } + + cli, err := envdclient.NewClientWithOpts(opts...) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to create the envd-server client") } return &envdServerEngine{ Client: cli, @@ -51,7 +63,7 @@ func New(ctx context.Context, opt Options) (Engine, error) { cli, err := client.NewClientWithOpts( client.FromEnv, client.WithAPIVersionNegotiation()) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to create the docker client") } return &dockerEngine{ Client: cli, diff --git a/pkg/ssh/ssh.go b/pkg/ssh/ssh.go index b1976517d..c0c2e4eed 100644 --- a/pkg/ssh/ssh.go +++ b/pkg/ssh/ssh.go @@ -57,7 +57,6 @@ type Options struct { func DefaultOptions() Options { return Options{ - Server: "localhost", User: "envd", Auth: true, PrivateKeyPwd: "", diff --git a/pkg/types/envd.go b/pkg/types/envd.go index 022917dbd..8e194efdb 100644 --- a/pkg/types/envd.go +++ b/pkg/types/envd.go @@ -22,6 +22,7 @@ import ( "github.com/docker/docker/api/types" "github.com/moby/buildkit/util/system" + "github.com/tensorchord/envd/pkg/util/netutil" "github.com/tensorchord/envd/pkg/version" ) @@ -278,3 +279,16 @@ func parsePyPICommands(lst string) ([]string, error) { err := json.Unmarshal([]byte(lst), &pkgs) return pkgs, err } + +func (c Context) GetSSHHostname() (string, error) { + if c.RunnerAddress == nil { + return "localhost", nil + } + + // TODO(gaocegege): Check ENVD_SERVER_HOST. + hostname, err := netutil.GetHost(*c.RunnerAddress) + if err != nil { + return "", err + } + return hostname, nil +} diff --git a/pkg/util/netutil/netutil.go b/pkg/util/netutil/netutil.go index a1b1da48a..228b1057c 100644 --- a/pkg/util/netutil/netutil.go +++ b/pkg/util/netutil/netutil.go @@ -14,8 +14,13 @@ package netutil -import "net" +import ( + "fmt" + "net" + "net/url" +) +// GetFreePort returns an available port in the host. func GetFreePort() (int, error) { l, err := net.Listen("tcp", ":0") if err != nil { @@ -24,3 +29,16 @@ func GetFreePort() (int, error) { defer l.Close() return l.Addr().(*net.TCPAddr).Port, nil } + +// GetHost get the IP address from the address. +func GetHost(addr string) (string, error) { + if u, err := url.Parse(addr); err != nil { + return "", err + } else { + h := u.Hostname() + if h == "" { + return "", fmt.Errorf("failed to get the hostname from %s", addr) + } + return h, nil + } +} diff --git a/pkg/util/netutil/netutil_test.go b/pkg/util/netutil/netutil_test.go index e6f93d9be..5a30a2c39 100644 --- a/pkg/util/netutil/netutil_test.go +++ b/pkg/util/netutil/netutil_test.go @@ -25,3 +25,49 @@ func TestGetFreePort(t *testing.T) { assert.NoError(t, err) assert.NotEqual(t, port, 0) } + +func TestGetHost(t *testing.T) { + tcs := []struct { + host string + expected string + err bool + }{ + { + host: "https://localhost:8080", + expected: "localhost", + err: false, + }, + { + host: "localhost:8080", + expected: "", + err: true, + }, + { + host: "http://localhost:8080", + expected: "localhost", + err: false, + }, + { + host: "http://1.1.1.1:8080", + expected: "1.1.1.1", + err: false, + }, + } + for _, tc := range tcs { + host, err := GetHost(tc.host) + if tc.err == true { + if err == nil { + t.Error("expect to get the error, but got nil") + } + continue + } + if tc.err == false { + if err != nil { + t.Errorf("unexpected error %v", err) + } + if host != tc.expected { + t.Errorf("expected %s, got %s", tc.expected, host) + } + } + } +}