Skip to content

Commit

Permalink
Fix kubeconfig download
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
  • Loading branch information
ArangoGutierrez committed Jan 9, 2024
1 parent 979f07d commit 1ff2836
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 52 deletions.
104 changes: 60 additions & 44 deletions cmd/create/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package create
import (
"fmt"
"io"
"log"
"os"
"path/filepath"

Expand Down Expand Up @@ -142,7 +141,7 @@ func (m command) run(c *cli.Context, opts *options) error {
func runProvision(opts *options) error {
var hostUrl string
if opts.cfg.Spec.Provider == v1alpha1.ProviderAWS {
for _, p := range opts.cfg.Status.Properties {
for _, p := range opts.cache.Status.Properties {
if p.Name == aws.PublicDnsName {
hostUrl = p.Value
break
Expand All @@ -163,57 +162,74 @@ func runProvision(opts *options) error {

// Download kubeconfig
if opts.cfg.Spec.Kubernetes.Install {
if opts.cfg.Spec.Kubernetes.KubeConfig == "" {
// and
if opts.kubeconfig == "" {
fmt.Printf("kubeconfig is not set, use default kubeconfig path: %s\n", filepath.Join(opts.cachePath, "kubeconfig"))
// if kubeconfig is not set, use set to current directory as default
// first get current directory
pwd := os.Getenv("PWD")
opts.kubeconfig = filepath.Join(pwd, "kubeconfig")
}
opts.cfg.Spec.Kubernetes.KubeConfig = opts.kubeconfig
if err = getKubeConfig(opts, p); err != nil {
return fmt.Errorf("failed to get kubeconfig: %v", err)
}
}

// Create a session
session, err := p.Client.NewSession()
if err != nil {
fmt.Printf("Failed to create session: %v\n", err)
return err
}
reader, writer := io.Pipe()
session.Stdout = writer
session.Stderr = writer
return nil
}

go func() {
defer writer.Close()
_, err := io.Copy(os.Stdout, reader)
if err != nil {
log.Fatalf("Failed to copy from reader: %v", err)
}
}()
defer session.Close()
// Create a new file on the local system to save the downloaded content
localFile, err := os.Create(opts.kubeconfig)
if err != nil {
return fmt.Errorf("error creating local file: %v", err)
func getKubeConfig(opts *options, p *provisioner.Provisioner) error {
remoteFilePath := "/home/ubuntu/.kube/config"
if opts.cfg.Spec.Kubernetes.KubeConfig == "" {
// and
if opts.kubeconfig == "" {
fmt.Printf("kubeconfig is not set, use default kubeconfig path: %s\n", filepath.Join(opts.cachePath, "kubeconfig"))
// if kubeconfig is not set, use set to current directory as default
// first get current directory
pwd := os.Getenv("PWD")
opts.kubeconfig = filepath.Join(pwd, "kubeconfig")
} else {
opts.cfg.Spec.Kubernetes.KubeConfig = opts.kubeconfig
}
defer localFile.Close()
}

// Set up pipes for stdin, stdout, and stderr
session.Stdout = localFile
session.Stderr = os.Stderr
// Create a session
session, err := p.Client.NewSession()
if err != nil {
fmt.Printf("Failed to create session: %v\n", err)
return err
}
defer session.Close()

// Run the SCP command to download the remote file
remoteFilePath := "/home/ubuntu/.kube/config"
err = session.Run("scp -f " + remoteFilePath)
if err != nil {
return fmt.Errorf("error running SCP command: %v", err)
}
// Set up a pipe to receive the remote file content
remoteFile, err := session.StdoutPipe()
if err != nil {
fmt.Printf("Error obtaining remote file pipe: %v\n", err)
return err
}

fmt.Println("KubeConfig downloaded successfully.")
// Start the remote command to read the file content
err = session.Start(fmt.Sprintf("/usr/bin/cat %s", remoteFilePath))
if err != nil {
fmt.Printf("Error starting remote command: %v\n", err)
return err
}

// Create a new file on the local system to save the downloaded content
localFile, err := os.Create(opts.kubeconfig)
if err != nil {
return fmt.Errorf("error creating local file: %v", err)
}
defer localFile.Close()

// Copy the remote file content to the local file
_, err = io.Copy(localFile, remoteFile)
if err != nil {
fmt.Printf("Error copying file content: %v\n", err)
return err
}

// Wait for the remote command to finish
err = session.Wait()
if err != nil {
fmt.Printf("Error waiting for remote command: %v\n", err)
return err
}

fmt.Printf("Kubeconfig saved to %s\n", opts.kubeconfig)

return nil
}

Expand Down
5 changes: 5 additions & 0 deletions cmd/delete/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ func (m command) build() *cli.Command {
return err
}

if opts.cfg.Spec.Provider != v1alpha1.ProviderAWS {
fmt.Printf("Only AWS provider is supported\n")
return err
}

// read hostUrl from cache
if opts.cachePath == "" {
opts.cachePath = filepath.Join(os.Getenv("HOME"), ".cache", "holodeck")
Expand Down
13 changes: 7 additions & 6 deletions pkg/provider/aws/dryrun.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,26 +21,27 @@ import (
"fmt"

"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
)

func (a *Client) getInstanceTypes() ([]string, error) {
func (a *Client) getInstanceTypes() ([]types.InstanceType, error) {
// Use the DescribeInstanceTypes API to get a list of supported instance types in the current region
resp, err := a.ec2.DescribeInstanceTypes(context.TODO(), &ec2.DescribeInstanceTypesInput{})
if err != nil {
return nil, err
}

instanceTypes := make([]string, 0)
instanceTypes := []types.InstanceType{}
for _, it := range resp.InstanceTypes {
instanceTypes = append(instanceTypes, string(it.InstanceType))
instanceTypes = append(instanceTypes, it.InstanceType)
}

return instanceTypes, nil
}

func (a *Client) isInstanceTypeSupported(desiredType string, supportedTypes []string) bool {
func (a *Client) isInstanceTypeSupported(desiredType string, supportedTypes []types.InstanceType) bool {
for _, t := range supportedTypes {
if t == desiredType {
if t == types.InstanceType(a.Spec.Instance.Type) {
return true
}
}
Expand All @@ -55,7 +56,7 @@ func (a *Client) DryRun() error {
}

if !a.isInstanceTypeSupported(string(a.Spec.Instance.Type), instanceTypes) {
return fmt.Errorf("instance type %s is not supported in the current region", string(a.Spec.Instance.Type))
return fmt.Errorf("instance type %s is not supported in the current region %s", string(a.Spec.Instance.Type), a.Spec.Instance.Region)
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/provisioner/provisioner.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ func connectOrDie(keyPath, hostUrl string) (*ssh.Client, error) {
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
fmt.Printf("Connecting to %s\n", hostUrl)

connectionFailed := false
for i := 0; i < 10; i++ {
for i := 0; i < 20; i++ {
client, err = ssh.Dial("tcp", hostUrl+":22", sshConfig)
if err == nil {
return client, nil // Connection succeeded, return the client.
Expand Down

0 comments on commit 1ff2836

Please sign in to comment.