Skip to content

Commit

Permalink
Multi-user VM support - start and SSH into playgrounds under non-root…
Browse files Browse the repository at this point in the history
… users
  • Loading branch information
iximiuz committed Aug 21, 2024
1 parent 2c4a53c commit 6d8655d
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 31 deletions.
2 changes: 1 addition & 1 deletion cmd/auth/logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func runLogout(ctx context.Context, cli labcli.CLI) error {
}

if err := ssh.RemoveIdentity(cli.Config().SSHDir); err != nil {
slog.Warn("Failed to remove SSH identity file: %v", err)
slog.Warn("Failed to remove SSH identity file", "error", err.Error())
}

cli.Config().SessionID = ""
Expand Down
24 changes: 17 additions & 7 deletions cmd/playground/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
type startOptions struct {
playground string
machine string
user string

open bool

Expand Down Expand Up @@ -55,6 +56,20 @@ func newStartCommand(cli labcli.CLI) *cobra.Command {

flags := cmd.Flags()

flags.StringVar(
&opts.machine,
"machine",
"",
`SSH into the machine with the given name (requires --ssh flag, default to the first machine)`,
)
flags.StringVarP(
&opts.user,
"user",
"u",
"",
`SSH user (default: the machine's default login user)`,
)

flags.BoolVar(
&opts.open,
"open",
Expand All @@ -73,12 +88,6 @@ func newStartCommand(cli labcli.CLI) *cobra.Command {
false,
`Open the playground in the IDE (only VSCode is supported at the moment)`,
)
flags.StringVar(
&opts.machine,
"machine",
"",
`SSH into the machine with the given name (requires --ssh flag, default to the first machine)`,
)
flags.BoolVarP(
&opts.quiet,
"quiet",
Expand Down Expand Up @@ -112,6 +121,7 @@ func runStartPlayground(ctx context.Context, cli labcli.CLI, opts *startOptions)
return sshproxy.RunSSHProxy(ctx, cli, &sshproxy.Options{
PlayID: play.ID,
Machine: opts.machine,
User: opts.user,
IDE: true,
})
}
Expand All @@ -127,7 +137,7 @@ func runStartPlayground(ctx context.Context, cli labcli.CLI, opts *startOptions)

cli.PrintAux("SSH-ing into %s machine...\n", opts.machine)

return ssh.RunSSHSession(ctx, cli, play.ID, opts.machine, nil)
return ssh.RunSSHSession(ctx, cli, play.ID, opts.machine, opts.user, nil)
}

cli.PrintOut("%s\n", play.ID)
Expand Down
1 change: 0 additions & 1 deletion cmd/portforward/portforward.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ func runPortForward(ctx context.Context, cli labcli.CLI, opts *options) error {
PlayID: opts.playID,
Machine: opts.machine,
PlaysDir: cli.Config().PlaysDir,
SSHDir: cli.Config().SSHDir,
})
if err != nil {
return fmt.Errorf("couldn't start tunnel: %w", err)
Expand Down
27 changes: 24 additions & 3 deletions cmd/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ const example = ` # SSH into the first machine in the playground
type options struct {
playID string
machine string
user string

command []string
}
Expand Down Expand Up @@ -57,6 +58,13 @@ func NewCommand(cli labcli.CLI) *cobra.Command {
"",
`Target machine (default: the first machine in the playground)`,
)
flags.StringVarP(
&opts.user,
"user",
"u",
"",
`SSH user (default: the machine's default login user)`,
)

return cmd
}
Expand All @@ -75,20 +83,33 @@ func runSSHSession(ctx context.Context, cli labcli.CLI, opts *options) error {
}
}

return RunSSHSession(ctx, cli, opts.playID, opts.machine, opts.command)
if opts.user == "" {
if u := p.GetMachine(opts.machine).DefaultUser(); u != nil {
opts.user = u.Name
} else {
opts.user = "root"
}
}
if !p.GetMachine(opts.machine).HasUser(opts.user) {
return fmt.Errorf("user %q not found in the machine %q", opts.user, opts.machine)
}

return RunSSHSession(ctx, cli, opts.playID, opts.machine, opts.user, opts.command)
}

func RunSSHSession(
ctx context.Context,
cli labcli.CLI,
playID string,
machine string,
user string,
command []string,
) error {
tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{
PlayID: playID,
Machine: machine,
PlaysDir: cli.Config().PlaysDir,
SSHUser: user,
SSHDir: cli.Config().SSHDir,
})
if err != nil {
Expand Down Expand Up @@ -120,7 +141,7 @@ func RunSSHSession(
return

case err := <-errCh:
slog.Debug("Tunnel error: %v", err)
slog.Debug("Tunnel borked", "error", err.Error())
}
}
}()
Expand All @@ -138,7 +159,7 @@ func RunSSHSession(
}
defer conn.Close()

sess, err := ssh.NewSession(conn, "root", cli.Config().SSHDir)
sess, err := ssh.NewSession(conn, user, cli.Config().SSHDir)
if err != nil {
return fmt.Errorf("couldn't create SSH session: %w", err)
}
Expand Down
42 changes: 35 additions & 7 deletions cmd/sshproxy/sshproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
type Options struct {
PlayID string
Machine string
User string
Address string

IDE bool
Expand Down Expand Up @@ -49,6 +50,12 @@ func NewCommand(cli labcli.CLI) *cobra.Command {
"",
`Target machine (default: the first machine in the playground)`,
)
flags.StringVar(
&opts.User,
"user",
"",
`Login user (default: the machine's default login user)`,
)
flags.StringVar(
&opts.Address,
"address",
Expand Down Expand Up @@ -79,10 +86,22 @@ func RunSSHProxy(ctx context.Context, cli labcli.CLI, opts *Options) error {
}
}

if opts.User == "" {
if u := p.GetMachine(opts.Machine).DefaultUser(); u != nil {
opts.User = u.Name
} else {
opts.User = "root"
}
}
if !p.GetMachine(opts.Machine).HasUser(opts.User) {
return fmt.Errorf("user %q not found in the machine %q", opts.User, opts.Machine)
}

tunnel, err := portforward.StartTunnel(ctx, cli.Client(), portforward.TunnelOptions{
PlayID: opts.PlayID,
Machine: opts.Machine,
PlaysDir: cli.Config().PlaysDir,
SSHUser: opts.User,
SSHDir: cli.Config().SSHDir,
})
if err != nil {
Expand Down Expand Up @@ -116,29 +135,30 @@ func RunSSHProxy(ctx context.Context, cli labcli.CLI, opts *Options) error {
return

case err := <-errCh:
slog.Debug("Tunnel error: %v", err)
slog.Debug("Tunnel borked", "error", err.Error())
}
}
}()

if !opts.IDE {
cli.PrintOut("SSH proxy is running on %s\n", localPort)
cli.PrintOut(
"\n# Connect from the terminal:\nssh -i %s/%s ssh://root@%s:%s\n",
cli.Config().SSHDir, ssh.IdentityFile, localHost, localPort,
"\n# Connect from the terminal:\nssh -i %s/%s ssh://%s@%s:%s\n",
cli.Config().SSHDir, ssh.IdentityFile, opts.User, localHost, localPort,
)

cli.PrintOut("\n# Or add the following to your ~/.ssh/config:\n")
cli.PrintOut("Host %s\n", opts.PlayID+"-"+opts.Machine)
cli.PrintOut(" HostName %s\n", localHost)
cli.PrintOut(" Port %s\n", localPort)
cli.PrintOut(" User root\n")
cli.PrintOut(" User %s\n", opts.User)
cli.PrintOut(" IdentityFile %s/%s\n", cli.Config().SSHDir, ssh.IdentityFile)
cli.PrintOut(" StrictHostKeyChecking no\n")
cli.PrintOut(" UserKnownHostsFile /dev/null\n\n")

cli.PrintOut("# To access the playground in Visual Studio Code:\n")
cli.PrintOut("code --folder-uri vscode-remote://ssh-remote+root@%s:%s/root\n\n", localHost, localPort)
cli.PrintOut("code --folder-uri vscode-remote://ssh-remote+%s@%s:%s/%s\n\n",
opts.User, localHost, localPort, userHomeDir(opts.User))

cli.PrintOut("\nPress Ctrl+C to stop\n")
} else {
Expand All @@ -151,12 +171,13 @@ func RunSSHProxy(ctx context.Context, cli labcli.CLI, opts *Options) error {
"-o", "IdentitiesOnly=yes",
"-o", "PreferredAuthentications=publickey",
"-i", fmt.Sprintf("%s/%s", cli.Config().SSHDir, ssh.IdentityFile),
fmt.Sprintf("ssh://root@%s:%s", localHost, localPort),
fmt.Sprintf("ssh://%s@%s:%s", opts.User, localHost, localPort),
)
cmd.Run()

cmd = exec.Command("code",
"--folder-uri", fmt.Sprintf("vscode-remote://ssh-remote+root@%s:%s/root", localHost, localPort),
"--folder-uri", fmt.Sprintf("vscode-remote://ssh-remote+%s@%s:%s/%s",
opts.User, localHost, localPort, userHomeDir(opts.User)),
)
if err := cmd.Run(); err != nil {
return fmt.Errorf("couldn't open the IDE: %w", err)
Expand Down Expand Up @@ -192,3 +213,10 @@ func hostStr(address string) string {

return strings.Split(address, ":")[0]
}

func userHomeDir(user string) string {
if user == "root" {
return "/root"
}
return fmt.Sprintf("/home/%s", user)
}
40 changes: 35 additions & 5 deletions internal/api/plays.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,41 @@ func (p *Play) GetMachine(name string) *Machine {
return nil
}

type MachineUser struct {
Name string `json:"name"`
Default bool `json:"default"`
}

type Machine struct {
Name string `json:"name"`
CPUCount int `json:"cpuCount"`
RAMSize string `json:"ramSize"`
DrivePerf string `json:"drivePerf"`
NetworkPerf string `json:"networkPerf"`
Name string `json:"name"`
Users []MachineUser `json:"users"`
CPUCount int `json:"cpuCount"`
RAMSize string `json:"ramSize"`
DrivePerf string `json:"drivePerf"`
NetworkPerf string `json:"networkPerf"`
}

func (m *Machine) DefaultUser() *MachineUser {
for _, u := range m.Users {
if u.Default {
return &u
}
}
return nil
}

func (m *Machine) HasUser(name string) bool {
if name == "root" {
// Everyone has root
return true
}

for _, u := range m.Users {
if u.Name == name {
return true
}
}
return false
}

type CreatePlayRequest struct {
Expand Down Expand Up @@ -116,6 +145,7 @@ type StartTunnelRequest struct {
Port int `json:"port"`
Access PortAccess `json:"access"`
GenerateLoginURL bool `json:"generateLoginUrl"`
SSHUser string `json:"sshUser"`
SSHPubKey string `json:"sshPubKey"`
}

Expand Down
12 changes: 9 additions & 3 deletions internal/portforward/tunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type TunnelOptions struct {
PlayID string
Machine string
PlaysDir string
SSHUser string
SSHDir string
}

Expand All @@ -35,7 +36,11 @@ type Tunnel struct {
}

func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (*Tunnel, error) {
tunnelFile := filepath.Join(opts.PlaysDir, opts.PlayID+"-"+opts.Machine, "tunnel.json")
uniq := opts.PlayID + "-" + opts.Machine
if opts.SSHUser != "" {
uniq += "-" + opts.SSHUser
}
tunnelFile := filepath.Join(opts.PlaysDir, uniq, "tunnel.json")
if t, err := loadTunnel(tunnelFile); err == nil {
return t, nil
}
Expand All @@ -53,9 +58,10 @@ func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (*

resp, err := client.StartTunnel(ctx, opts.PlayID, api.StartTunnelRequest{
Machine: opts.Machine,
SSHPubKey: sshPubKey,
Access: api.PortAccessPrivate,
GenerateLoginURL: true,
SSHUser: opts.SSHUser,
SSHPubKey: sshPubKey,
})
if err != nil {
return nil, fmt.Errorf("client.StartTunnel(): %w", err)
Expand All @@ -75,7 +81,7 @@ func StartTunnel(ctx context.Context, client *api.Client, opts TunnelOptions) (*
}

if err := saveTunnel(tunnelFile, t); err != nil {
slog.Warn("Couldn't save tunnel info to file: %v", err)
slog.Warn("Couldn't save tunnel info to file", "error", err.Error())
}

return t, nil
Expand Down
6 changes: 3 additions & 3 deletions internal/ssh/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ func (s *Session) Run(ctx context.Context, streams labcli.Streams, cmd string) e

if streams.InputStream().IsTerminal() {
if err := streams.InputStream().SetRawTerminal(); err != nil {
slog.Warn("Could not enable raw terminal mode", err)
slog.Warn("Could not enable raw terminal mode", "error", err.Error())
} else {
defer streams.InputStream().RestoreTerminal()

Expand All @@ -90,7 +90,7 @@ func (s *Session) Run(ctx context.Context, streams labcli.Streams, cmd string) e

go func() {
if err := watchWindowSize(ctx, streams.OutputStream(), sess); err != nil {
slog.Debug("Error watching window size", err)
slog.Debug("Error watching window size", "error", err.Error())
}
}()
}
Expand Down Expand Up @@ -121,7 +121,7 @@ func (s *Session) Run(ctx context.Context, streams labcli.Streams, cmd string) e
err = sess.Shell()
if err == nil {
if err := sess.Wait(); err != nil {
slog.Debug("Error waiting for shell", err)
slog.Debug("Waiting for shell failed", "error", err.Error())
}
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func main() {
}

// Hopefully, only usage errors.
slog.Debug("Exit error: %s", err)
slog.Debug("Exit error: " + err.Error())
os.Exit(1)
}
}
Expand Down

0 comments on commit 6d8655d

Please sign in to comment.