Skip to content

Commit

Permalink
feat(WSL): Add ssh config entry to Windows ssh config if using WSL (#604
Browse files Browse the repository at this point in the history
)

* initial

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* add key to windows entry

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* lint

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>

* fix

Signed-off-by: Jinjing.Zhou <allenzhou@tensorchord.ai>
  • Loading branch information
VoVAllen authored Jul 14, 2022
1 parent 27c4cde commit 178b8da
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pkg/app/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func bootstrap(clicontext *cli.Context) error {
var exists bool
var newPrivateKeyName string
for ok := true; ok; ok = exists {
newPrivateKeyName = filepath.Join(filepath.Dir(sshconfig.GetPrivateKey()), fmt.Sprintf("%s.pk", namesgenerator.GetRandomName(0)))
newPrivateKeyName = filepath.Join(filepath.Dir(sshconfig.GetPrivateKey()), fmt.Sprintf("envd_%s.pk", namesgenerator.GetRandomName(0)))
exists, err = fileutil.FileExists(newPrivateKeyName)
if err != nil {
return err
Expand Down
86 changes: 83 additions & 3 deletions pkg/ssh/config/ssh_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"strings"

"github.com/sirupsen/logrus"

"github.com/tensorchord/envd/pkg/util/osutil"
)

type (
Expand Down Expand Up @@ -302,7 +304,28 @@ func buildHostname(name string) string {

// AddEntry adds an entry to the user's sshconfig
func AddEntry(name, iface string, port int, privateKeyPath string) error {
return add(getSSHConfigPath(), buildHostname(name), iface, port, privateKeyPath)
err := add(getSSHConfigPath(), buildHostname(name), iface, port, privateKeyPath)
if err != nil {
return err
}
if osutil.IsWsl() {
logrus.Debug("Try adding entry to WSL's ssh-agent")
winSshConfig, err := osutil.GetWslHostSshConfig()
if err != nil {
return err
}
winKeyPath, err := osutil.CopyToWinEnvdHome(privateKeyPath, 0600)
if err != nil {
return err
}
// Add the entry to the WSL host SSH config
logrus.Debugf("Adding entry to WSL's ssh-agent: %s", winSshConfig)
err = add(winSshConfig, buildHostname(name), iface, port, winKeyPath)
if err != nil {
return err
}
}
return nil
}

func ReplaceKeyManagedByEnvd(oldKey string, newKey string) error {
Expand All @@ -329,7 +352,49 @@ func ReplaceKeyManagedByEnvd(oldKey string, newKey string) error {
if err != nil {
return err
}
return save(cfg, getSSHConfigPath())

err = save(cfg, getSSHConfigPath())
if err != nil {
return err
}

if osutil.IsWsl() {
winSshConfig, err := osutil.GetWslHostSshConfig()
if err != nil {
return err
}
cfg, err := getConfig(winSshConfig)
if err != nil {
return err
}
winNewKey, err := osutil.CopyToWinEnvdHome(newKey, 0600)
if err != nil {
return err
}
winOldKey, err := osutil.CopyToWinEnvdHome(oldKey, 0600)
if err != nil {
return err
}
logrus.Infof("Rewrite WSL ssh keys old: %s, new: %s", winOldKey, winNewKey)
for ih, h := range cfg.hosts {
for _, hn := range h.hostnames {
logrus.Info(h.hostnames)
if strings.HasSuffix(hn, ".envd") {
for ip, p := range h.params {
if p.keyword == identityFile && strings.Trim(p.args[0], "\"") == winOldKey {
logrus.Debug("Change key")
cfg.hosts[ih].params[ip].args[0] = winNewKey
}
}
}
}
}
err = save(cfg, winSshConfig)
if err != nil {
return err
}
}
return nil
}

func add(path, name, iface string, port int, privateKeyPath string) error {
Expand Down Expand Up @@ -361,7 +426,22 @@ func add(path, name, iface string, port int, privateKeyPath string) error {

// RemoveEntry removes the entry to the user's sshconfig if found
func RemoveEntry(name string) error {
return remove(getSSHConfigPath(), buildHostname(name))
err := remove(getSSHConfigPath(), buildHostname(name))
if err != nil {
return err
}
if osutil.IsWsl() {
logrus.Debug("Try removing entry from WSL's ssh-agent")
winSshConfig, err := osutil.GetWslHostSshConfig()
if err != nil {
return err
}
err = remove(winSshConfig, buildHostname(name))
if err != nil {
return err
}
}
return nil
}

// GetPort returns the corresponding SSH port for the dev env
Expand Down
157 changes: 157 additions & 0 deletions pkg/util/osutil/wsl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
// Copyright 2022 The envd Authors
// Copyright 2022 mateors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package osutil

import (
"fmt"
"io"
"net"
"os"
"os/exec"
"path"
"path/filepath"
"strings"

"github.com/cockroachdb/errors"
"github.com/sirupsen/logrus"
)

func IsWsl() bool {
// Return false if meet error
cmd := exec.Command("cat", "/proc/version")
output, err := cmd.Output()
if err != nil {
logrus.Debugf("Error when check whether sys is WSL: %v", err)
return false
}

return strings.Contains(strings.ToLower(string(output)), "microsoft")
}

func GetWslHostSshConfig() (string, error) {
userCmd := exec.Command("wslvar", "USERPROFILE")
userOutput, err := userCmd.Output()
if err != nil {
return "", err
}

cmd := exec.Command("wslpath", string(userOutput))
output, err := cmd.Output()
if err != nil {
return "", err
}
outputPath := path.Join(strings.Trim(string(output), "\n"), ".ssh", "config")
logrus.Debugf("wsl sshconfig path: %s", outputPath)
return outputPath, nil
}

func GetWslIp() (string, error) {
ip, err := getInterfaceIpv4Addr("eth0")
if err != nil {
return "", err
}
return ip, nil
}

func GetWindowsEnvdConfigHome() (string, error) {

userCmd := exec.Command("wslvar", "LOCALAPPDATA")
userOutput, err := userCmd.Output()
if err != nil {
return "", err
}

cmd := exec.Command("wslpath", string(userOutput))
output, err := cmd.Output()
if err != nil {
return "", err
}
envdDir := filepath.Join(strings.Trim(string(output), "\n"), "envd")
if err := os.MkdirAll(envdDir, 0755); err != nil {
return "", err
}
return envdDir, nil
}

// from: https://gist.github.com/schwarzeni/f25031a3123f895ff3785970921e962c
func getInterfaceIpv4Addr(interfaceName string) (addr string, err error) {
var (
ief *net.Interface
addrs []net.Addr
ipv4Addr net.IP
)
if ief, err = net.InterfaceByName(interfaceName); err != nil { // get interface
return
}
if addrs, err = ief.Addrs(); err != nil { // get addresses
return
}
for _, addr := range addrs { // get ipv4 address
if ipv4Addr = addr.(*net.IPNet).IP.To4(); ipv4Addr != nil {
break
}
}
if ipv4Addr == nil {
return "", errors.New(fmt.Sprintf("interface %s don't have an ipv4 address\n", interfaceName))
}
return ipv4Addr.String(), nil
}

func CopyToWinEnvdHome(src string, permission os.FileMode) (string, error) {
// Return dst path in windows format
winhome, err := GetWindowsEnvdConfigHome()
if err != nil {
return "", err
}
filename := filepath.Base(src)
dst := filepath.Join(winhome, filename)
err = copy(src, dst, permission)
if err != nil {
return "", err
}

envdDirWinCmd := exec.Command("wslpath", "-w", dst)
winDir, err := envdDirWinCmd.Output()

if err != nil {
return "", err
}
return strings.Trim(string(winDir), "\n"), nil
}

func copy(src, dst string, permission os.FileMode) error {
in, err := os.Open(src)
if err != nil {
return err
}
defer in.Close()

out, err := os.Create(dst)
if err != nil {
return err
}
defer out.Close()
err = os.Chmod(dst, permission)
if err != nil {
return err
}

_, err = io.Copy(out, in)
if err != nil {
return err
}
return out.Close()
}

0 comments on commit 178b8da

Please sign in to comment.