Skip to content

Commit

Permalink
feat: Auth ssh with key (#205)
Browse files Browse the repository at this point in the history
* auth ssh with key

* address comment

* use error

* remove create directory

* Generate ssh key when init

* address comment

* rename package

* test CI

* try fix ci

* fix typo

Co-authored-by: Jinjing.Zhou <allenzhou@tensorchord.ai>
  • Loading branch information
VoVAllen and VoVAllen authored May 28, 2022
1 parent e0984bb commit 40c7d06
Show file tree
Hide file tree
Showing 15 changed files with 291 additions and 47 deletions.
14 changes: 7 additions & 7 deletions cmd/envd-ssh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import (
"github.com/sirupsen/logrus"
cli "github.com/urfave/cli/v2"

"github.com/tensorchord/envd/pkg/config"
"github.com/tensorchord/envd/pkg/remote/sshd"
"github.com/tensorchord/envd/pkg/ssh"
"github.com/tensorchord/envd/pkg/version"
)

const (
authorizedKeysPath = "/var/envd/remote/authorized_keys"
envPort = "envd_SSH_PORT"
flagDebug = "debug"
flagAuthKey = "authorized-keys"
flagNoAuth = "no-auth"
envPort = "envd_SSH_PORT"
flagDebug = "debug"
flagAuthKey = "authorized-keys"
flagNoAuth = "no-auth"
)

func main() {
Expand All @@ -55,8 +55,8 @@ func main() {
},
&cli.StringFlag{
Name: flagAuthKey,
Usage: "path to authorized keys file, defaults to " + authorizedKeysPath,
Value: authorizedKeysPath,
Usage: "path to authorized keys file, defaults to " + config.ContainerauthorizedKeysPath,
Value: config.ContainerauthorizedKeysPath,
Aliases: []string{"a"},
},
&cli.BoolFlag{
Expand Down
9 changes: 8 additions & 1 deletion cmd/envd/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"github.com/tensorchord/envd/pkg/builder"
"github.com/tensorchord/envd/pkg/flag"
"github.com/tensorchord/envd/pkg/home"
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
"github.com/tensorchord/envd/pkg/util/fileutil"
)

Expand All @@ -50,6 +51,12 @@ var CommandBuild = &cli.Command{
Aliases: []string{"p"},
Value: ".",
},
&cli.PathFlag{
Name: "public-key",
Usage: "Path to the public key",
Aliases: []string{"pubk"},
Value: sshconfig.GetPublicKey(),
},
},

Action: build,
Expand Down Expand Up @@ -92,5 +99,5 @@ func build(clicontext *cli.Context) error {
if err != nil {
return errors.Wrap(err, "failed to create the builder")
}
return builder.Build(clicontext.Context)
return builder.Build(clicontext.Path("public-key"), clicontext.Context)
}
4 changes: 2 additions & 2 deletions cmd/envd/destroy.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
cli "github.com/urfave/cli/v2"

"github.com/tensorchord/envd/pkg/docker"
"github.com/tensorchord/envd/pkg/ssh"
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
"github.com/tensorchord/envd/pkg/util/fileutil"
)

Expand Down Expand Up @@ -59,7 +59,7 @@ func destroy(clicontext *cli.Context) error {
return errors.Wrapf(err, "failed to destroy the environment: %s", ctr)
}

if err = ssh.RemoveEntry(ctr); err != nil {
if err = sshconfig.RemoveEntry(ctr); err != nil {
logrus.Infof("failed to remove entry %s from your SSH config file: %s", ctr, err)
return errors.Wrap(err, "failed to remove entry from your SSH config file")
}
Expand Down
26 changes: 17 additions & 9 deletions cmd/envd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/tensorchord/envd/pkg/home"
"github.com/tensorchord/envd/pkg/lang/ir"
"github.com/tensorchord/envd/pkg/ssh"
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
"github.com/tensorchord/envd/pkg/util/fileutil"
)

Expand Down Expand Up @@ -60,16 +61,22 @@ var CommandUp = &cli.Command{
Aliases: []string{"f"},
Value: "build.envd",
},
&cli.BoolFlag{
Name: "auth",
Usage: "Enable authentication for ssh",
Value: false,
},
// &cli.BoolFlag{
// Name: "auth",
// Usage: "Enable authentication for ssh",
// Value: false,
// },
&cli.PathFlag{
Name: "private-key",
Usage: "Path to the private key",
Aliases: []string{"k"},
Value: "~/.ssh/id_rsa",
Value: sshconfig.GetPrivateKey(),
},
&cli.PathFlag{
Name: "public-key",
Usage: "Path to the public key",
Aliases: []string{"pubk"},
Value: sshconfig.GetPublicKey(),
},
&cli.DurationFlag{
Name: "timeout",
Expand Down Expand Up @@ -129,7 +136,7 @@ func up(clicontext *cli.Context) error {
return errors.Wrap(err, "failed to create the builder")
}

if err := builder.Build(clicontext.Context); err != nil {
if err := builder.Build(clicontext.Path("public-key"), clicontext.Context); err != nil {
return err
}
gpu := builder.GPUEnabled()
Expand All @@ -138,6 +145,7 @@ func up(clicontext *cli.Context) error {
if err != nil {
return err
}

containerID, containerIP, err := dockerClient.StartEnvd(clicontext.Context,
tag, ctr, buildContext, gpu, *ir.DefaultGraph, clicontext.Duration("timeout"),
clicontext.StringSlice("volume"))
Expand All @@ -147,14 +155,14 @@ func up(clicontext *cli.Context) error {
logrus.Debugf("container %s is running", containerID)

logrus.Debugf("Add entry %s to SSH config. at %s", buildContext, containerIP)
if err = ssh.AddEntry(ctr, containerIP, ssh.DefaultSSHPort); err != nil {
if err = sshconfig.AddEntry(ctr, containerIP, ssh.DefaultSSHPort, clicontext.Path("private-key")); err != nil {
logrus.Infof("failed to add entry %s to your SSH config file: %s", ctr, err)
return errors.Wrap(err, "failed to add entry to your SSH config file")
}

if !detach {
sshClient, err := ssh.NewClient(
containerIP, "root", ssh.DefaultSSHPort, clicontext.Bool("auth"), clicontext.Path("private-key"), "")
containerIP, "envd", ssh.DefaultSSHPort, true, clicontext.Path("private-key"), "")
if err != nil {
return err
}
Expand Down
10 changes: 5 additions & 5 deletions pkg/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import (
)

type Builder interface {
Build(ctx context.Context) error
Build(pub string, ctx context.Context) error
GPUEnabled() bool
}

Expand Down Expand Up @@ -81,8 +81,8 @@ func (b generalBuilder) GPUEnabled() bool {
return ir.GPUEnabled()
}

func (b generalBuilder) Build(ctx context.Context) error {
def, err := b.compile(ctx)
func (b generalBuilder) Build(pub string, ctx context.Context) error {
def, err := b.compile(pub, ctx)
if err != nil {
return errors.Wrap(err, "failed to compile")
}
Expand Down Expand Up @@ -110,11 +110,11 @@ func (b generalBuilder) interpret() error {
return nil
}

func (b generalBuilder) compile(ctx context.Context) (*llb.Definition, error) {
func (b generalBuilder) compile(pub string, ctx context.Context) (*llb.Definition, error) {
if err := b.interpret(); err != nil {
return nil, errors.Wrap(err, "failed to interpret")
}
def, err := ir.Compile(ctx, fileutil.Base(b.buildContextDir))
def, err := ir.Compile(ctx, fileutil.Base(b.buildContextDir), pub)
if err != nil {
return nil, errors.Wrap(err, "failed to compile build.envd")
}
Expand Down
7 changes: 5 additions & 2 deletions pkg/builder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/tensorchord/envd/pkg/progress/compileui"
compileuimock "github.com/tensorchord/envd/pkg/progress/compileui/mock"
"github.com/tensorchord/envd/pkg/progress/progresswriter"
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
)

var _ = Describe("Builder", func() {
Expand Down Expand Up @@ -90,21 +91,23 @@ var _ = Describe("Builder", func() {
b.Interpreter.(*mockstarlark.MockInterpreter).EXPECT().ExecFile(
gomock.Eq(configFilePath),
).Return(nil, expected)
err := b.Build(context.TODO())
pub := sshconfig.GetPublicKey()
err := b.Build(pub, context.TODO())
Expect(err).To(HaveOccurred())
})
})

When("failed to interpret manifest", func() {
It("should get an error", func() {
expected := errors.New("failed to interpret manifest")
pub := sshconfig.GetPublicKey()
b.Interpreter.(*mockstarlark.MockInterpreter).EXPECT().ExecFile(
gomock.Eq(configFilePath),
).Return(nil, nil)
b.Interpreter.(*mockstarlark.MockInterpreter).EXPECT().ExecFile(
gomock.Eq(b.manifestFilePath),
).Return(nil, expected)
err := b.Build(context.TODO())
err := b.Build(pub, context.TODO())
Expect(err).To(HaveOccurred())
})
})
Expand Down
33 changes: 33 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2022 The envd Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package config

import (
"path/filepath"

"github.com/adrg/xdg"
)

type UpState string

const (
PrivateKeyFile = "id_rsa_envd"
PublicKeyFile = "id_rsa_envd.pub"
ContainerauthorizedKeysPath = "/var/envd/authorized_keys"
envdFolderName = ".envd"
)

func GetEnvdHome() string {
return filepath.Join(xdg.ConfigHome, envdFolderName)
}
9 changes: 6 additions & 3 deletions pkg/docker/entrypoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,24 @@ import (
"fmt"
"strings"

"github.com/tensorchord/envd/pkg/config"
"github.com/tensorchord/envd/pkg/editor/jupyter"
"github.com/tensorchord/envd/pkg/lang/ir"
)

const (
template = `set -e
/var/envd/bin/envd-ssh --no-auth &
/var/envd/bin/envd-ssh --authorized-keys %s &
%s
wait -n`
)

func entrypointSH(g ir.Graph, workingDir string) string {
if g.JupyterConfig != nil {
cmds := jupyter.GenerateCommand(g, workingDir)
return fmt.Sprintf(template, strings.Join(cmds, " "))
return fmt.Sprintf(template,
config.ContainerauthorizedKeysPath, strings.Join(cmds, " "))
}
return fmt.Sprintf(template, "")
return fmt.Sprintf(template,
config.ContainerauthorizedKeysPath, "")
}
15 changes: 13 additions & 2 deletions pkg/home/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/adrg/xdg"
"github.com/cockroachdb/errors"
"github.com/sirupsen/logrus"
sshconfig "github.com/tensorchord/envd/pkg/ssh/config"
"github.com/tensorchord/envd/pkg/util/fileutil"
)

Expand Down Expand Up @@ -85,7 +86,7 @@ func (m generalManager) Cached(key string) bool {
func (m *generalManager) dumpCacheStatus() error {
file, err := os.Create(m.cacheStatusFile)
if err != nil {
return errors.Wrap(err, "failed to open cache status file")
return errors.Wrap(err, "failed to create cache status file")
}
defer file.Close()

Expand Down Expand Up @@ -120,9 +121,14 @@ func (m *generalManager) init() error {
if err != nil {
if os.IsNotExist(err) {
logrus.WithField("filename", m.cacheStatusFile).Debug("Creating file")
if _, err := os.Create(m.cacheStatusFile); err != nil {
file, err := os.Create(m.cacheStatusFile)
if err != nil {
return errors.Wrap(err, "failed to create file")
}
err = file.Close()
if err != nil {
return errors.Wrap(err, "failed to close file")
}
if err := m.dumpCacheStatus(); err != nil {
return errors.Wrap(err, "failed to dump cache status")
}
Expand All @@ -131,6 +137,11 @@ func (m *generalManager) init() error {
}
}

// Generate SSH keys when init
if err := sshconfig.GenerateKeys(); err != nil {
return errors.Wrap(err, "failed to generate ssh key")
}

file, err := os.Open(m.cacheStatusFile)
if err != nil {
return errors.Wrap(err, "failed to open cache status file")
Expand Down
2 changes: 1 addition & 1 deletion pkg/lang/frontend/starlark/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func NewInterpreter() Interpreter {
}

func (s generalInterpreter) ExecFile(filename string) (interface{}, error) {
logrus.WithField("filename", filename).Debug("inperprete the file")
logrus.WithField("filename", filename).Debug("interprete the file")
var src interface{}
globals, err := starlark.ExecFile(s.Thread, filename, src, nil)
if err != nil {
Expand Down
9 changes: 7 additions & 2 deletions pkg/lang/ir/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@ func GPUEnabled() bool {
return DefaultGraph.CUDA != nil
}

func Compile(ctx context.Context, cachePrefix string) (*llb.Definition, error) {
func Compile(ctx context.Context, cachePrefix string, pub string) (*llb.Definition, error) {
w, err := compileui.New(ctx, os.Stdout, "auto")
if err != nil {
return nil, errors.Wrap(err, "failed to create compileui")
}
DefaultGraph.Writer = w
DefaultGraph.CachePrefix = cachePrefix
DefaultGraph.PublicKeyPath = pub
state, err := DefaultGraph.Compile()
if err != nil {
return nil, err
Expand Down Expand Up @@ -109,7 +110,11 @@ func (g Graph) Compile() (llb.State, error) {
diffShellStage := llb.Diff(builtinSystemStage, shellStage, llb.WithCustomName("install shell"))
pypiStage := llb.Diff(builtinSystemStage, g.compilePyPIPackages(builtinSystemStage), llb.WithCustomName("install PyPI packages"))
systemStage := llb.Diff(builtinSystemStage, g.compileSystemPackages(builtinSystemStage), llb.WithCustomName("install system packages"))
sshStage := g.copyEnvdSSHServer()
sshStage, err := g.copyEnvdSSHServerWithKey()

if err != nil {
return llb.State{}, errors.Wrap(err, "failed to copy SSH key")
}

vscodeStage, err := g.compileVSCode()
if err != nil {
Expand Down
Loading

0 comments on commit 40c7d06

Please sign in to comment.