diff --git a/cmd/create/create.go b/cmd/create/create.go index 04517bf6..1a3181b0 100644 --- a/cmd/create/create.go +++ b/cmd/create/create.go @@ -19,7 +19,6 @@ package create import ( "fmt" "io" - "log" "os" "path/filepath" @@ -142,7 +141,7 @@ func (m command) run(c *cli.Context, opts *options) error { func runProvision(opts *options) error { var hostUrl string if opts.cfg.Spec.Provider == v1alpha1.ProviderAWS { - for _, p := range opts.cfg.Status.Properties { + for _, p := range opts.cache.Status.Properties { if p.Name == aws.PublicDnsName { hostUrl = p.Value break @@ -163,57 +162,74 @@ func runProvision(opts *options) error { // Download kubeconfig if opts.cfg.Spec.Kubernetes.Install { - if opts.cfg.Spec.Kubernetes.KubeConfig == "" { - // and - if opts.kubeconfig == "" { - fmt.Printf("kubeconfig is not set, use default kubeconfig path: %s\n", filepath.Join(opts.cachePath, "kubeconfig")) - // if kubeconfig is not set, use set to current directory as default - // first get current directory - pwd := os.Getenv("PWD") - opts.kubeconfig = filepath.Join(pwd, "kubeconfig") - } - opts.cfg.Spec.Kubernetes.KubeConfig = opts.kubeconfig + if err = getKubeConfig(opts, p); err != nil { + return fmt.Errorf("failed to get kubeconfig: %v", err) } + } - // Create a session - session, err := p.Client.NewSession() - if err != nil { - fmt.Printf("Failed to create session: %v\n", err) - return err - } - reader, writer := io.Pipe() - session.Stdout = writer - session.Stderr = writer + return nil +} - go func() { - defer writer.Close() - _, err := io.Copy(os.Stdout, reader) - if err != nil { - log.Fatalf("Failed to copy from reader: %v", err) - } - }() - defer session.Close() - // Create a new file on the local system to save the downloaded content - localFile, err := os.Create(opts.kubeconfig) - if err != nil { - return fmt.Errorf("error creating local file: %v", err) +func getKubeConfig(opts *options, p *provisioner.Provisioner) error { + remoteFilePath := "/home/ubuntu/.kube/config" + if opts.cfg.Spec.Kubernetes.KubeConfig == "" { + // and + if opts.kubeconfig == "" { + fmt.Printf("kubeconfig is not set, use default kubeconfig path: %s\n", filepath.Join(opts.cachePath, "kubeconfig")) + // if kubeconfig is not set, use set to current directory as default + // first get current directory + pwd := os.Getenv("PWD") + opts.kubeconfig = filepath.Join(pwd, "kubeconfig") + } else { + opts.cfg.Spec.Kubernetes.KubeConfig = opts.kubeconfig } - defer localFile.Close() + } - // Set up pipes for stdin, stdout, and stderr - session.Stdout = localFile - session.Stderr = os.Stderr + // Create a session + session, err := p.Client.NewSession() + if err != nil { + fmt.Printf("Failed to create session: %v\n", err) + return err + } + defer session.Close() - // Run the SCP command to download the remote file - remoteFilePath := "/home/ubuntu/.kube/config" - err = session.Run("scp -f " + remoteFilePath) - if err != nil { - return fmt.Errorf("error running SCP command: %v", err) - } + // Set up a pipe to receive the remote file content + remoteFile, err := session.StdoutPipe() + if err != nil { + fmt.Printf("Error obtaining remote file pipe: %v\n", err) + return err + } - fmt.Println("KubeConfig downloaded successfully.") + // Start the remote command to read the file content + err = session.Start(fmt.Sprintf("/usr/bin/cat %s", remoteFilePath)) + if err != nil { + fmt.Printf("Error starting remote command: %v\n", err) + return err } + // Create a new file on the local system to save the downloaded content + localFile, err := os.Create(opts.kubeconfig) + if err != nil { + return fmt.Errorf("error creating local file: %v", err) + } + defer localFile.Close() + + // Copy the remote file content to the local file + _, err = io.Copy(localFile, remoteFile) + if err != nil { + fmt.Printf("Error copying file content: %v\n", err) + return err + } + + // Wait for the remote command to finish + err = session.Wait() + if err != nil { + fmt.Printf("Error waiting for remote command: %v\n", err) + return err + } + + fmt.Printf("Kubeconfig saved to %s\n", opts.kubeconfig) + return nil } diff --git a/cmd/delete/delete.go b/cmd/delete/delete.go index 697c7308..d131e37d 100644 --- a/cmd/delete/delete.go +++ b/cmd/delete/delete.go @@ -81,6 +81,11 @@ func (m command) build() *cli.Command { return err } + if opts.cfg.Spec.Provider != v1alpha1.ProviderAWS { + fmt.Printf("Only AWS provider is supported\n") + return err + } + // read hostUrl from cache if opts.cachePath == "" { opts.cachePath = filepath.Join(os.Getenv("HOME"), ".cache", "holodeck") diff --git a/pkg/provider/aws/dryrun.go b/pkg/provider/aws/dryrun.go index 990acc4a..35561641 100644 --- a/pkg/provider/aws/dryrun.go +++ b/pkg/provider/aws/dryrun.go @@ -21,26 +21,27 @@ import ( "fmt" "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) -func (a *Client) getInstanceTypes() ([]string, error) { +func (a *Client) getInstanceTypes() ([]types.InstanceType, error) { // Use the DescribeInstanceTypes API to get a list of supported instance types in the current region resp, err := a.ec2.DescribeInstanceTypes(context.TODO(), &ec2.DescribeInstanceTypesInput{}) if err != nil { return nil, err } - instanceTypes := make([]string, 0) + instanceTypes := []types.InstanceType{} for _, it := range resp.InstanceTypes { - instanceTypes = append(instanceTypes, string(it.InstanceType)) + instanceTypes = append(instanceTypes, it.InstanceType) } return instanceTypes, nil } -func (a *Client) isInstanceTypeSupported(desiredType string, supportedTypes []string) bool { +func (a *Client) isInstanceTypeSupported(desiredType string, supportedTypes []types.InstanceType) bool { for _, t := range supportedTypes { - if t == desiredType { + if t == types.InstanceType(a.Spec.Instance.Type) { return true } } @@ -55,7 +56,7 @@ func (a *Client) DryRun() error { } if !a.isInstanceTypeSupported(string(a.Spec.Instance.Type), instanceTypes) { - return fmt.Errorf("instance type %s is not supported in the current region", string(a.Spec.Instance.Type)) + return fmt.Errorf("instance type %s is not supported in the current region %s", string(a.Spec.Instance.Type), a.Spec.Instance.Region) } return nil diff --git a/pkg/provisioner/provisioner.go b/pkg/provisioner/provisioner.go index 2d4a81e6..2614eefb 100644 --- a/pkg/provisioner/provisioner.go +++ b/pkg/provisioner/provisioner.go @@ -119,9 +119,9 @@ func connectOrDie(keyPath, hostUrl string) (*ssh.Client, error) { }, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - fmt.Printf("Connecting to %s\n", hostUrl) + connectionFailed := false - for i := 0; i < 10; i++ { + for i := 0; i < 20; i++ { client, err = ssh.Dial("tcp", hostUrl+":22", sshConfig) if err == nil { return client, nil // Connection succeeded, return the client.