From 1950833d77bfa8c0cb39af398719aed1592c9c5c Mon Sep 17 00:00:00 2001 From: Carlos Eduardo Arango Gutierrez Date: Wed, 24 Jan 2024 21:19:03 +0100 Subject: [PATCH] Multiple enhancements Signed-off-by: Carlos Eduardo Arango Gutierrez --- api/holodeck/v1alpha1/types.go | 2 + cmd/create/create.go | 63 ++-- cmd/delete/delete.go | 33 +- cmd/dryrun/dryrun.go | 76 ++++- cmd/main.go | 22 +- examples/v1alpha1_environment.yaml | 5 +- internal/logger/logger.go | 154 +++++++++ pkg/provider/aws/aws.go | 5 +- pkg/provider/aws/create.go | 190 ++++++++---- pkg/provider/aws/delete.go | 18 +- pkg/provider/aws/dryrun.go | 20 +- pkg/provisioner/dependency.go | 192 ++++++++++++ pkg/provisioner/dryrun.go | 24 +- pkg/provisioner/provisioner.go | 292 ++++++------------ pkg/provisioner/templates/common.go | 11 + .../templates/container-toolkit.go | 35 ++- pkg/provisioner/templates/containerd.go | 41 ++- pkg/provisioner/templates/crio.go | 29 +- pkg/provisioner/templates/docker.go | 36 ++- pkg/provisioner/templates/kubernetes.go | 80 +++-- pkg/provisioner/templates/nv-driver.go | 26 +- 21 files changed, 966 insertions(+), 388 deletions(-) create mode 100644 internal/logger/logger.go create mode 100644 pkg/provisioner/dependency.go diff --git a/api/holodeck/v1alpha1/types.go b/api/holodeck/v1alpha1/types.go index 808683cc..30fc583e 100644 --- a/api/holodeck/v1alpha1/types.go +++ b/api/holodeck/v1alpha1/types.go @@ -157,6 +157,8 @@ const ( ContainerRuntimeContainerd ContainerRuntimeName = "containerd" // ContainerRuntimeCrio means the container runtime is Crio ContainerRuntimeCrio ContainerRuntimeName = "crio" + // ContainerRuntimeNone means the container runtime is not defined + ContainerRuntimeNone ContainerRuntimeName = "" ) type Kubernetes struct { diff --git a/cmd/create/create.go b/cmd/create/create.go index ca28b8a7..68309993 100644 --- a/cmd/create/create.go +++ b/cmd/create/create.go @@ -23,11 +23,11 @@ import ( "path/filepath" "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/internal/logger" "github.com/NVIDIA/holodeck/pkg/jyaml" "github.com/NVIDIA/holodeck/pkg/provider/aws" "github.com/NVIDIA/holodeck/pkg/provisioner" - "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" ) @@ -44,13 +44,13 @@ type options struct { } type command struct { - logger *logrus.Logger + log *logger.FunLogger } // NewCommand constructs the create command with the specified logger -func NewCommand(logger *logrus.Logger) *cli.Command { +func NewCommand(log *logger.FunLogger) *cli.Command { c := command{ - logger: logger, + log: log, } return c.build() } @@ -93,8 +93,7 @@ func (m command) build() *cli.Command { var err error opts.cfg, err = jyaml.UnmarshalFromFile[v1alpha1.Environment](opts.envFile) if err != nil { - fmt.Printf("failed to read config file: %v\n", err) - return err + return fmt.Errorf("error reading config file: %s", err) } // set cache path @@ -103,6 +102,11 @@ func (m command) build() *cli.Command { } opts.cachefile = filepath.Join(opts.cachePath, opts.cfg.Name+".yaml") + // if no containerruntime is specified, default to none + if opts.cfg.Spec.ContainerRuntime.Name == "" { + opts.cfg.Spec.ContainerRuntime.Name = v1alpha1.ContainerRuntimeNone + } + return nil }, Action: func(c *cli.Context) error { @@ -115,7 +119,7 @@ func (m command) build() *cli.Command { func (m command) run(c *cli.Context, opts *options) error { if opts.cfg.Spec.Provider == v1alpha1.ProviderAWS { - err := createAWS(opts) + err := createAWS(m.log, opts) if err != nil { return fmt.Errorf("failed to create AWS infra: %v", err) } @@ -125,11 +129,12 @@ func (m command) run(c *cli.Context, opts *options) error { return fmt.Errorf("failed to read cache file: %v", err) } } else if opts.cfg.Spec.Provider == v1alpha1.ProviderSSH { + m.log.Info("SSH infrastructure \u2601") opts.cache = opts.cfg } if opts.provision { - err := runProvision(opts) + err := runProvision(m.log, opts) if err != nil { return fmt.Errorf("failed to provision: %v", err) } @@ -138,8 +143,11 @@ func (m command) run(c *cli.Context, opts *options) error { return nil } -func runProvision(opts *options) error { +func runProvision(log *logger.FunLogger, opts *options) error { var hostUrl string + + log.Info("Provisioning \u2699") + if opts.cfg.Spec.Provider == v1alpha1.ProviderAWS { for _, p := range opts.cache.Status.Properties { if p.Name == aws.PublicDnsName { @@ -151,7 +159,7 @@ func runProvision(opts *options) error { hostUrl = opts.cfg.Spec.Instance.HostUrl } - p, err := provisioner.New(opts.cfg.Spec.Auth.PrivateKey, hostUrl) + p, err := provisioner.New(log, opts.cfg.Spec.Auth.PrivateKey, hostUrl) if err != nil { return err } @@ -161,8 +169,13 @@ func runProvision(opts *options) error { } // Download kubeconfig - if opts.cfg.Spec.Kubernetes.Install { - if err = getKubeConfig(opts, p); err != nil { + if opts.cfg.Spec.Kubernetes.Install && opts.cfg.Spec.Kubernetes.KubeConfig != "" { + if opts.cfg.Spec.Kubernetes.KubernetesInstaller == "microk8s" || opts.cfg.Spec.Kubernetes.KubernetesInstaller == "kind" { + log.Warning("kubeconfig is not supported for %s, skipping kubeconfig download", opts.cfg.Spec.Kubernetes.KubernetesInstaller) + return nil + } + + if err = getKubeConfig(log, opts, p); err != nil { return fmt.Errorf("failed to get kubeconfig: %v", err) } } @@ -171,12 +184,12 @@ func runProvision(opts *options) error { } // getKubeConfig downloads the kubeconfig file from the remote host -func getKubeConfig(opts *options, p *provisioner.Provisioner) error { +func getKubeConfig(log *logger.FunLogger, opts *options, p *provisioner.Provisioner) error { remoteFilePath := "/home/ubuntu/.kube/config" if opts.cfg.Spec.Kubernetes.KubeConfig == "" { // and if opts.kubeconfig == "" { - fmt.Printf("kubeconfig is not set, use default kubeconfig path: %s\n", filepath.Join(opts.cachePath, "kubeconfig")) + log.Warning("kubeconfig is not set, use default kubeconfig path: %s\n", filepath.Join(opts.cachePath, "kubeconfig")) // if kubeconfig is not set, use set to current directory as default // first get current directory pwd := os.Getenv("PWD") @@ -189,23 +202,20 @@ func getKubeConfig(opts *options, p *provisioner.Provisioner) error { // Create a session session, err := p.Client.NewSession() if err != nil { - fmt.Printf("Failed to create session: %v\n", err) - return err + return fmt.Errorf("error creating session: %v", err) } defer session.Close() // Set up a pipe to receive the remote file content remoteFile, err := session.StdoutPipe() if err != nil { - fmt.Printf("Error obtaining remote file pipe: %v\n", err) - return err + return fmt.Errorf("error creating remote file pipe: %v", err) } // Start the remote command to read the file content err = session.Start(fmt.Sprintf("/usr/bin/cat %s", remoteFilePath)) if err != nil { - fmt.Printf("Error starting remote command: %v\n", err) - return err + return fmt.Errorf("error starting remote command: %v", err) } // Create a new file on the local system to save the downloaded content @@ -218,24 +228,23 @@ func getKubeConfig(opts *options, p *provisioner.Provisioner) error { // Copy the remote file content to the local file _, err = io.Copy(localFile, remoteFile) if err != nil { - fmt.Printf("Error copying file content: %v\n", err) - return err + return fmt.Errorf("error copying remote file to local: %v", err) } // Wait for the remote command to finish err = session.Wait() if err != nil { - fmt.Printf("Error waiting for remote command: %v\n", err) - return err + return fmt.Errorf("error waiting for remote command: %v", err) } - fmt.Printf("Kubeconfig saved to %s\n", opts.kubeconfig) + log.Info(fmt.Sprintf("Kubeconfig saved to %s\n", opts.kubeconfig)) return nil } -func createAWS(opts *options) error { - client, err := aws.New(opts.cfg, opts.cachefile) +func createAWS(log *logger.FunLogger, opts *options) error { + log.Info("Creating AWS infrastructure \u2601") + client, err := aws.New(log, opts.cfg, opts.cachefile) if err != nil { return err } diff --git a/cmd/delete/delete.go b/cmd/delete/delete.go index 9cbcdabf..03c06637 100644 --- a/cmd/delete/delete.go +++ b/cmd/delete/delete.go @@ -22,10 +22,10 @@ import ( "path/filepath" "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/internal/logger" "github.com/NVIDIA/holodeck/pkg/jyaml" "github.com/NVIDIA/holodeck/pkg/provider/aws" - "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" ) @@ -40,13 +40,13 @@ type options struct { } type command struct { - logger *logrus.Logger + log *logger.FunLogger } // NewCommand constructs the delete command with the specified logger -func NewCommand(logger *logrus.Logger) *cli.Command { +func NewCommand(log *logger.FunLogger) *cli.Command { c := command{ - logger: logger, + log: log, } return c.build() } @@ -77,13 +77,11 @@ func (m command) build() *cli.Command { var err error opts.cfg, err = jyaml.UnmarshalFromFile[v1alpha1.Environment](opts.envFile) if err != nil { - fmt.Printf("Error reading config file: %s\n", err) - return err + return fmt.Errorf("error reading config file: %s", err) } if opts.cfg.Spec.Provider != v1alpha1.ProviderAWS { - fmt.Printf("Only AWS provider is supported\n") - return err + return fmt.Errorf("provider %s not supported", opts.cfg.Spec.Provider) } // read hostUrl from cache @@ -113,15 +111,15 @@ func (m command) run(c *cli.Context, opts *options) error { cfg, err := jyaml.UnmarshalFromFile[v1alpha1.Environment](opts.envFile) if err != nil { - fmt.Printf("Error reading config file: %s\n", err) - os.Exit(1) + m.log.Error(err) + m.log.Exit(1) } cachefile := filepath.Join(opts.cachePath, cfg.Name+".yaml") - client, err := aws.New(cfg, cachefile) + client, err := aws.New(m.log, cfg, cachefile) if err != nil { - fmt.Printf("Error creating client: %s\n", err) - os.Exit(1) + m.log.Error(err) + m.log.Exit(1) } // check if cache exists @@ -131,13 +129,12 @@ func (m command) run(c *cli.Context, opts *options) error { os.Exit(1) } - err = client.Delete() - if err != nil { - fmt.Printf("Error deleting environment: %s\n", err) - os.Exit(1) + if err := client.Delete(); err != nil { + m.log.Error(err) + m.log.Exit(1) } - fmt.Printf("Successfully deleted environment %s\n", cfg.Name) + m.log.Info("Successfully deleted environment %s\n", cfg.Name) return nil } diff --git a/cmd/dryrun/dryrun.go b/cmd/dryrun/dryrun.go index 9796b11d..cce1d883 100644 --- a/cmd/dryrun/dryrun.go +++ b/cmd/dryrun/dryrun.go @@ -18,13 +18,16 @@ package dryrun import ( "fmt" + "os" + "time" "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/internal/logger" "github.com/NVIDIA/holodeck/pkg/jyaml" "github.com/NVIDIA/holodeck/pkg/provider/aws" "github.com/NVIDIA/holodeck/pkg/provisioner" + "golang.org/x/crypto/ssh" - "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" ) @@ -35,13 +38,13 @@ type options struct { } type command struct { - logger *logrus.Logger + log *logger.FunLogger } -// NewCommand constructs a dryrun command with the specified logger -func NewCommand(logger *logrus.Logger) *cli.Command { +// NewCommand constructs the DryRun command with the specified logger +func NewCommand(log *logger.FunLogger) *cli.Command { c := command{ - logger: logger, + log: log, } return c.build() } @@ -66,8 +69,11 @@ func (m command) build() *cli.Command { var err error opts.cfg, err = jyaml.UnmarshalFromFile[v1alpha1.Environment](opts.envFile) if err != nil { - fmt.Printf("failed to read config file: %v\n", err) - return err + return fmt.Errorf("failed to read config file %s: %v", opts.envFile, err) + } + + if opts.cfg.Spec.ContainerRuntime.Name == "" && opts.cfg.Spec.Kubernetes.Install { + m.log.Warning("No container runtime specified, will default defaulting to containerd") } return nil @@ -81,17 +87,17 @@ func (m command) build() *cli.Command { } func (m command) run(c *cli.Context, opts *options) error { + m.log.Info("Dryrun environment %s \U0001f50d", opts.cfg.Name) + // Check Provider switch opts.cfg.Spec.Provider { case v1alpha1.ProviderAWS: - err := validateAWS(opts) + err := validateAWS(m.log, opts) if err != nil { return err } case v1alpha1.ProviderSSH: - // Creating a new provisioner will validate the private key and hostUrl - _, err := provisioner.New(opts.cfg.Spec.Auth.PrivateKey, opts.cfg.Spec.Instance.HostUrl) - if err != nil { + if err := connectOrDie(opts.cfg.Spec.Auth.PrivateKey, opts.cfg.Spec.Instance.HostUrl); err != nil { return err } default: @@ -99,18 +105,17 @@ func (m command) run(c *cli.Context, opts *options) error { } // Check Provisioner - err := provisioner.Dryrun(opts.cfg) - if err != nil { + if err := provisioner.Dryrun(m.log, opts.cfg); err != nil { return err } - fmt.Printf("Dryrun succeeded\n") + m.log.Check("Dryrun succeeded\n") return nil } -func validateAWS(opts *options) error { - client, err := aws.New(opts.cfg, opts.envFile) +func validateAWS(log *logger.FunLogger, opts *options) error { + client, err := aws.New(log, opts.cfg, opts.envFile) if err != nil { return err } @@ -121,3 +126,42 @@ func validateAWS(opts *options) error { return nil } + +// createSshClient creates a ssh client, and retries if it fails to connect +func connectOrDie(keyPath, hostUrl string) error { + var err error + key, err := os.ReadFile(keyPath) + if err != nil { + return fmt.Errorf("failed to read key file: %v", err) + } + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return fmt.Errorf("failed to parse private key: %v", err) + } + sshConfig := &ssh.ClientConfig{ + User: "ubuntu", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + connectionFailed := false + for i := 0; i < 20; i++ { + client, err := ssh.Dial("tcp", hostUrl+":22", sshConfig) + if err == nil { + client.Close() + return nil // Connection succeeded, + } + connectionFailed = true + // Sleep for a brief moment before retrying. + // You can adjust the duration based on your requirements. + time.Sleep(1 * time.Second) + } + + if connectionFailed { + return fmt.Errorf("failed to connect to %s", hostUrl) + } + + return nil +} diff --git a/cmd/main.go b/cmd/main.go index cced0cb9..6bb3f567 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -22,8 +22,8 @@ import ( "github.com/NVIDIA/holodeck/cmd/create" "github.com/NVIDIA/holodeck/cmd/delete" "github.com/NVIDIA/holodeck/cmd/dryrun" + "github.com/NVIDIA/holodeck/internal/logger" - log "github.com/sirupsen/logrus" cli "github.com/urfave/cli/v2" ) @@ -32,7 +32,7 @@ const ( ProgramName = "holodeck" ) -var logger = log.New() +var log = logger.NewLogger() type config struct { Debug bool @@ -58,26 +58,16 @@ func main() { }, } - // Set log-level for all subcommands - c.Before = func(c *cli.Context) error { - logLevel := log.InfoLevel - if config.Debug { - logLevel = log.DebugLevel - } - logger.SetLevel(logLevel) - return nil - } - // Define the subcommands c.Commands = []*cli.Command{ - create.NewCommand(logger), - delete.NewCommand(logger), - dryrun.NewCommand(logger), + create.NewCommand(log), + delete.NewCommand(log), + dryrun.NewCommand(log), } err := c.Run(os.Args) if err != nil { - log.Errorf("%v", err) + log.Error(err) log.Exit(1) } } diff --git a/examples/v1alpha1_environment.yaml b/examples/v1alpha1_environment.yaml index de1702cb..b230c28b 100644 --- a/examples/v1alpha1_environment.yaml +++ b/examples/v1alpha1_environment.yaml @@ -18,8 +18,7 @@ spec: imageId: ami-0fe8bec493a81c7da containerRuntime: install: true - name: containerd kubernetes: install: true - installer: kubeadm - version: v1.28.5 + installer: kind + version: 1.29 diff --git a/internal/logger/logger.go b/internal/logger/logger.go new file mode 100644 index 00000000..58a67908 --- /dev/null +++ b/internal/logger/logger.go @@ -0,0 +1,154 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package logger + +import ( + "fmt" + "io" + "os" + "sync" + "time" +) + +const ( + // ANSI escape code to reset color + reset = "\033[0m" + // ANSI escape code for green color + green = "\033[32m" + // ANSI escape code for yellow text + yellowText = "\033[33m" + // ANSI escape code for red text + redText = "\033[31m" + // Unicode code point for the checkmark + checkmark = "\u2714" + // Unicode character for the red X emoji + redXEmoji = "\u274C" + // Unicode character for the warning sign + warningSign = "\u26A0" +) + +// NewLogger creates a new instance of FunLogger. +func NewLogger() *FunLogger { + return &FunLogger{ + Out: os.Stderr, + Done: make(chan struct{}), + Fail: make(chan struct{}), + Wg: &sync.WaitGroup{}, + ExitFunc: os.Exit, + } +} + +// Printer interface defines methods for logging info, warning, and error messages. +type Logger interface { + Info(format string, a ...any) + Check(format string, a ...any) + Warning(format string, a ...any) + Error(err error) + Loading(loadingMessage string) +} + +// FunFonts implements the Logger interface using emojis for messages. +type FunLogger struct { + // The logs are `io.Copy`'d to this in a mutex. It's common to set this to a + // file, or leave it default which is `os.Stderr`. You can also set this to + // something more adventurous, such as logging to Kafka. + Out io.Writer + // Function to exit the application, defaults to `os.Exit()` + ExitFunc exitFunc + // Done is a channel that can be used to stop the loading animation. + Done chan struct{} + // Fail is a channel that can be used to stop the loading animation and print a failure message. + Fail chan struct{} + // Wg is a WaitGroup that can be used to wait for the loading animation to finish. + Wg *sync.WaitGroup +} + +// Info prints an information message with no emoji. +func (l *FunLogger) Info(format string, a ...any) { + if format[len(format)-1] != '\n' { + format += "\n" + } + + fmt.Fprintf(l.Out, format, a...) +} + +// Info prints an information message with a check emoji. +func (l *FunLogger) Check(format string, a ...any) { + message := fmt.Sprintf(format, a...) + printMessage(green, checkmark, message) +} + +// Warning prints a warning message with a warning emoji. +func (l *FunLogger) Warning(format string, a ...any) { + message := fmt.Sprintf(format, a...) + printMessage(yellowText, warningSign, message) +} + +// Error prints an error message with an X emoji. +func (l *FunLogger) Error(err error) { + printMessage(redText, redXEmoji, err.Error()) +} + +// printMessage is a helper function to print the message with the specified emoji. +func printMessage(color, emoji, message string) { + fmt.Printf("%s%s%s\t%s\n", color, emoji, reset, message) +} + +func (l *FunLogger) Loading(format string, a ...any) { + defer l.Wg.Done() + message := fmt.Sprintf(format, a...) + // if message ends with a newline, remove it + if message[len(message)-1] == '\n' { + message = message[:len(message)-1] + } + + ticker := time.After(330 * time.Millisecond) + i := 0 + + spinners := []string{"|", "/", "-", "\\"} + + for { + select { + case <-l.Done: + fmt.Print("\r\033[2K") + printMessage(green, checkmark, message) + return + case <-l.Fail: + fmt.Print("\r\033[2K") + printMessage(redText, redXEmoji, message) + return + case <-ticker: + i++ + fmt.Printf("\r%s\t%s", spinners[i], message) + if i >= len(spinners)-1 { + i = 0 + } + + ticker = time.After(330 * time.Millisecond) + } + } +} + +func (l *FunLogger) Exit(code int) { + // Stop the loading animation if it's running + close(l.Done) + l.Wg.Wait() + + l.ExitFunc(code) +} + +type exitFunc func(int) diff --git a/pkg/provider/aws/aws.go b/pkg/provider/aws/aws.go index 8438fb68..8b42b845 100644 --- a/pkg/provider/aws/aws.go +++ b/pkg/provider/aws/aws.go @@ -21,6 +21,7 @@ import ( "os" "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/internal/logger" "github.com/NVIDIA/holodeck/pkg/jyaml" "sigs.k8s.io/yaml" @@ -77,9 +78,10 @@ type Client struct { cacheFile string *v1alpha1.Environment + log *logger.FunLogger } -func New(env v1alpha1.Environment, cacheFile string) (*Client, error) { +func New(log *logger.FunLogger, env v1alpha1.Environment, cacheFile string) (*Client, error) { // Create an AWS session and configure the EC2 client cfg, err := config.LoadDefaultConfig(context.TODO(), config.WithRegion(env.Spec.Region)) if err != nil { @@ -99,6 +101,7 @@ func New(env v1alpha1.Environment, cacheFile string) (*Client, error) { r53, cacheFile, &env, + log, } return c, nil diff --git a/pkg/provider/aws/create.go b/pkg/provider/aws/create.go index ddf2b490..bf3ba26d 100644 --- a/pkg/provider/aws/create.go +++ b/pkg/provider/aws/create.go @@ -29,22 +29,53 @@ import ( // Create creates an EC2 instance with proper Network configuration // VPC, Subnet, Internet Gateway, Route Table, Security Group func (a *Client) Create() error { - a.updateProgressingCondition(*a.Environment.DeepCopy(), &AWS{}, "v1alpha1.Creating", "Creating AWS resources") + cache := new(AWS) + defer a.dumpCache(cache) - if cache, err := a.create(); err != nil { - a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating AWS resources") - return fmt.Errorf("error creating AWS resources: %v", err) + a.updateProgressingCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Creating AWS resources") + + if err := a.createVPC(cache); err != nil { + a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating VPC") + return fmt.Errorf("error creating VPC: %v", err) + } + + if err := a.createSubnet(cache); err != nil { + a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating subnet") + return fmt.Errorf("error creating subnet: %v", err) + } + + if err := a.createInternetGateway(cache); err != nil { + a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating Internet Gateway") + return fmt.Errorf("error creating Internet Gateway: %v", err) + } + + if err := a.createRouteTable(cache); err != nil { + a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating route table") + return fmt.Errorf("error creating route table: %v", err) + } + + if err := a.createSecurityGroup(cache); err != nil { + a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating security group") + return fmt.Errorf("error creating security group: %v", err) } + if err := a.createEC2Instance(cache); err != nil { + a.updateDegradedCondition(*a.Environment.DeepCopy(), cache, "v1alpha1.Creating", "Error creating EC2 instance") + return fmt.Errorf("error creating EC2 instance: %v", err) + } + + // Save objects ID's into a cache file + if err := a.updateAvailableCondition(*a.Environment, cache); err != nil { + return fmt.Errorf("error creating cache file: %v", err) + } return nil } -func (a *Client) create() (*AWS, error) { - var cache AWS - defer a.dumpCache(&cache) +// createVPC creates a VPC with CIDR +func (a *Client) createVPC(cache *AWS) error { + a.log.Wg.Add(1) + go a.log.Loading("Creating VPC") - // Define the VPC parameters - fmt.Println("Creating VPC with CIDR") vpcInput := &ec2.CreateVpcInput{ CidrBlock: aws.String("10.0.0.0/16"), AmazonProvidedIpv6CidrBlock: &no, @@ -59,7 +90,8 @@ func (a *Client) create() (*AWS, error) { vpcOutput, err := a.ec2.CreateVpc(context.TODO(), vpcInput) if err != nil { - return &cache, fmt.Errorf("error creating VPC: %v", err) + a.fail() + return fmt.Errorf("error creating VPC: %v", err) } cache.Vpcid = *vpcOutput.Vpc.VpcId @@ -67,16 +99,24 @@ func (a *Client) create() (*AWS, error) { VpcId: vpcOutput.Vpc.VpcId, EnableDnsHostnames: &types.AttributeBooleanValue{Value: &yes}, } - fmt.Printf("Enabling DNS hostnames for VPC %s\n", cache.Vpcid) + _, err = a.ec2.ModifyVpcAttribute(context.Background(), modVcp) if err != nil { - return &cache, fmt.Errorf("error modifying VPC attributes: %v", err) + a.fail() + return fmt.Errorf("error modifying VPC attributes: %v", err) } + a.done() + + return nil +} + +// createSubnet creates a subnet for the VPC +func (a *Client) createSubnet(cache *AWS) error { + a.log.Wg.Add(1) + go a.log.Loading("Creating subnet") - // Create a subnet - fmt.Println("Creating subnet") subnetInput := &ec2.CreateSubnetInput{ - VpcId: vpcOutput.Vpc.VpcId, + VpcId: aws.String(cache.Vpcid), CidrBlock: aws.String("10.0.0.0/24"), TagSpecifications: []types.TagSpecification{ { @@ -87,37 +127,53 @@ func (a *Client) create() (*AWS, error) { } subnetOutput, err := a.ec2.CreateSubnet(context.TODO(), subnetInput) if err != nil { - return &cache, fmt.Errorf("error creating subnet: %v", err) + a.fail() + return fmt.Errorf("error creating subnet: %v", err) } cache.Subnetid = *subnetOutput.Subnet.SubnetId - // Create an Internet Gateway - fmt.Println("Creating Internet Gateway") + a.done() + return nil +} + +// createInternetGateway creates an Internet Gateway and attaches it to the VPC +func (a *Client) createInternetGateway(cache *AWS) error { + a.log.Wg.Add(1) + go a.log.Loading("Creating Internet Gateway") + gwInput := &ec2.CreateInternetGatewayInput{} gwOutput, err := a.ec2.CreateInternetGateway(context.TODO(), gwInput) if err != nil { - return &cache, fmt.Errorf("error creating Internet Gateway: %v", err) + a.fail() + return fmt.Errorf("error creating Internet Gateway: %v", err) } cache.InternetGwid = *gwOutput.InternetGateway.InternetGatewayId // Attach Internet Gateway to the VPC - fmt.Println("Attaching Internet Gateway to the VPC") attachInput := &ec2.AttachInternetGatewayInput{ - VpcId: vpcOutput.Vpc.VpcId, + VpcId: aws.String(cache.Vpcid), InternetGatewayId: gwOutput.InternetGateway.InternetGatewayId, } _, err = a.ec2.AttachInternetGateway(context.TODO(), attachInput) if err != nil { - return &cache, fmt.Errorf("error attaching Internet Gateway: %v", err) + a.fail() + return fmt.Errorf("error attaching Internet Gateway: %v", err) } if len(gwOutput.InternetGateway.Attachments) > 0 { cache.InternetGatewayAttachment = *gwOutput.InternetGateway.Attachments[0].VpcId } - // Create a route table - fmt.Println("Creating route table") + a.done() + return nil +} + +// createRouteTable creates a route table and associates it with the subnet +func (a *Client) createRouteTable(cache *AWS) error { + a.log.Wg.Add(1) + go a.log.Loading("Creating route table") + rtInput := &ec2.CreateRouteTableInput{ - VpcId: vpcOutput.Vpc.VpcId, + VpcId: aws.String(cache.Vpcid), TagSpecifications: []types.TagSpecification{ { ResourceType: types.ResourceTypeRouteTable, @@ -127,38 +183,44 @@ func (a *Client) create() (*AWS, error) { } rtOutput, err := a.ec2.CreateRouteTable(context.TODO(), rtInput) if err != nil { - return &cache, fmt.Errorf("error creating route table: %v", err) + a.fail() + return fmt.Errorf("error creating route table: %v", err) } cache.RouteTable = *rtOutput.RouteTable.RouteTableId // Associate the route table with the subnet - fmt.Println("Associating route table with the subnet") assocInput := &ec2.AssociateRouteTableInput{ RouteTableId: rtOutput.RouteTable.RouteTableId, - SubnetId: subnetOutput.Subnet.SubnetId, + SubnetId: aws.String(cache.Subnetid), } - _, err = a.ec2.AssociateRouteTable(context.Background(), assocInput) - if err != nil { - return &cache, fmt.Errorf("error associating route table: %v", err) + if _, err = a.ec2.AssociateRouteTable(context.Background(), assocInput); err != nil { + a.fail() + return fmt.Errorf("error associating route table: %v", err) } routeInput := &ec2.CreateRouteInput{ RouteTableId: rtOutput.RouteTable.RouteTableId, DestinationCidrBlock: aws.String("0.0.0.0/0"), - GatewayId: gwOutput.InternetGateway.InternetGatewayId, + GatewayId: aws.String(cache.InternetGwid), } - _, err = a.ec2.CreateRoute(context.TODO(), routeInput) - if err != nil { - return &cache, fmt.Errorf("error creating route: %v", err) + if _, err = a.ec2.CreateRoute(context.TODO(), routeInput); err != nil { + return fmt.Errorf("error creating route: %v", err) } - // Create a security group to allow external communication with K8S control - // plane - fmt.Println("Creating security group") + a.done() + return nil +} + +// createSecurityGroup creates a security group to allow external communication +// with K8S control plane and SSH +func (a *Client) createSecurityGroup(cache *AWS) error { + a.log.Wg.Add(1) + go a.log.Loading("Creating security group") + sgInput := &ec2.CreateSecurityGroupInput{ GroupName: &a.ObjectMeta.Name, Description: &description, - VpcId: vpcOutput.Vpc.VpcId, + VpcId: aws.String(cache.Vpcid), TagSpecifications: []types.TagSpecification{ { ResourceType: types.ResourceTypeSecurityGroup, @@ -168,7 +230,8 @@ func (a *Client) create() (*AWS, error) { } sgOutput, err := a.ec2.CreateSecurityGroup(context.TODO(), sgInput) if err != nil { - return &cache, fmt.Errorf("error creating security group: %v", err) + a.fail() + return fmt.Errorf("error creating security group: %v", err) } cache.SecurityGroupid = *sgOutput.GroupId @@ -204,13 +267,20 @@ func (a *Client) create() (*AWS, error) { }, } - _, err = a.ec2.AuthorizeSecurityGroupIngress(context.TODO(), irInput) - if err != nil { - return &cache, fmt.Errorf("error authorizing security group ingress: %v", err) + if _, err = a.ec2.AuthorizeSecurityGroupIngress(context.TODO(), irInput); err != nil { + a.fail() + return fmt.Errorf("error authorizing security group ingress: %v", err) } - // Create an EC2 instance - fmt.Printf("Creating EC2 instance with image %s\n", *a.Spec.Image.ImageId) + a.done() + return nil +} + +// createEC2Instance creates an EC2 instance with proper Network configuration +func (a *Client) createEC2Instance(cache *AWS) error { + a.log.Wg.Add(1) + go a.log.Loading("Creating EC2 instance") + instanceIn := &ec2.RunInstancesInput{ ImageId: a.Spec.Image.ImageId, InstanceType: types.InstanceType(a.Spec.Instance.Type), @@ -231,9 +301,9 @@ func (a *Client) create() (*AWS, error) { DeleteOnTermination: &yes, DeviceIndex: aws.Int32(0), Groups: []string{ - *sgOutput.GroupId, + cache.SecurityGroupid, }, - SubnetId: subnetOutput.Subnet.SubnetId, + SubnetId: aws.String(cache.Subnetid), }, }, KeyName: aws.String(a.Spec.Auth.KeyName), @@ -246,7 +316,8 @@ func (a *Client) create() (*AWS, error) { } instanceOut, err := a.ec2.RunInstances(context.Background(), instanceIn) if err != nil { - return &cache, fmt.Errorf("error creating instance: %v", err) + a.fail() + return fmt.Errorf("error creating instance: %v", err) } cache.Instanceid = *instanceOut.Instances[0].InstanceId @@ -257,12 +328,12 @@ func (a *Client) create() (*AWS, error) { }, } waiter := ec2.NewInstanceRunningWaiter(a.ec2, waiterOptions...) - fmt.Printf("Waiting for instance %s to be in running state\n", cache.Instanceid) - err = waiter.Wait(context.Background(), &ec2.DescribeInstancesInput{ + + if err = waiter.Wait(context.Background(), &ec2.DescribeInstancesInput{ InstanceIds: []string{*instanceOut.Instances[0].InstanceId}, - }, 5*time.Minute, waiterOptions...) - if err != nil { - return &cache, fmt.Errorf("error waiting for instance to be in running state: %v", err) + }, 5*time.Minute, waiterOptions...); err != nil { + a.fail() + return fmt.Errorf("error waiting for instance to be in running state: %v", err) } // Describe instance now that is running @@ -270,16 +341,11 @@ func (a *Client) create() (*AWS, error) { InstanceIds: []string{*instanceOut.Instances[0].InstanceId}, }) if err != nil { - return &cache, fmt.Errorf("error describing instances: %v", err) + a.fail() + return fmt.Errorf("error describing instances: %v", err) } - cache.PublicDnsName = *instanceRunning.Reservations[0].Instances[0].PublicDnsName - // Save objects ID's into a cache file - err = a.updateAvailableCondition(*a.Environment, &cache) - if err != nil { - return &cache, fmt.Errorf("error creating cache file: %v", err) - } - - return &cache, nil + a.done() + return nil } diff --git a/pkg/provider/aws/delete.go b/pkg/provider/aws/delete.go index daeb2d4e..3b78bc25 100644 --- a/pkg/provider/aws/delete.go +++ b/pkg/provider/aws/delete.go @@ -48,7 +48,10 @@ func (a *Client) delete(cache *AWS) error { if err != nil { return fmt.Errorf("error deleting instance: %v", err) } - fmt.Printf("Waiting for instance %s to be terminated\n", cache.Instanceid) + + a.log.Wg.Add(1) + go a.log.Loading("Waiting for instance %s to be terminated\n", cache.Instanceid) + waiterOptions := []func(*ec2.InstanceTerminatedWaiterOptions){ func(o *ec2.InstanceTerminatedWaiterOptions) { o.MaxDelay = 1 * time.Minute @@ -56,10 +59,10 @@ func (a *Client) delete(cache *AWS) error { }, } wait := ec2.NewInstanceTerminatedWaiter(a.ec2, waiterOptions...) - err = wait.Wait(context.Background(), &ec2.DescribeInstancesInput{ + if err := wait.Wait(context.Background(), &ec2.DescribeInstancesInput{ InstanceIds: []string{cache.Instanceid}, - }, 5*time.Minute, waiterOptions...) - if err != nil { + }, 5*time.Minute, waiterOptions...); err != nil { + a.fail() return fmt.Errorf("error waiting for instance to be terminated: %v", err) } @@ -69,6 +72,7 @@ func (a *Client) delete(cache *AWS) error { } _, err = a.ec2.DeleteSecurityGroup(context.Background(), deleteSecurityGroup) if err != nil { + a.fail() return fmt.Errorf("error deleting security group: %v", err) } @@ -78,6 +82,7 @@ func (a *Client) delete(cache *AWS) error { } _, err = a.ec2.DeleteSubnet(context.Background(), deleteSubnet) if err != nil { + a.fail() return fmt.Errorf("error deleting subnet: %v", err) } @@ -87,6 +92,7 @@ func (a *Client) delete(cache *AWS) error { } _, err = a.ec2.DeleteRouteTable(context.Background(), deleteRouteTable) if err != nil { + a.fail() return fmt.Errorf("error deleting route table: %v", err) } @@ -97,6 +103,7 @@ func (a *Client) delete(cache *AWS) error { } _, err = a.ec2.DetachInternetGateway(context.Background(), detachInternetGateway) if err != nil { + a.fail() return fmt.Errorf("error detaching Internet Gateway: %v", err) } @@ -106,6 +113,7 @@ func (a *Client) delete(cache *AWS) error { } _, err = a.ec2.DeleteInternetGateway(context.Background(), deleteInternetGatewayInput) if err != nil { + a.fail() return fmt.Errorf("error deleting Internet Gateway: %v", err) } @@ -115,8 +123,10 @@ func (a *Client) delete(cache *AWS) error { } _, err = a.ec2.DeleteVpc(context.Background(), dVpc) if err != nil { + a.fail() return fmt.Errorf("error deleting VPC: %v", err) } + a.done() return a.updateTerminatedCondition(*a.Environment, cache) } diff --git a/pkg/provider/aws/dryrun.go b/pkg/provider/aws/dryrun.go index e479160d..04786957 100644 --- a/pkg/provider/aws/dryrun.go +++ b/pkg/provider/aws/dryrun.go @@ -81,18 +81,34 @@ func (a *Client) checkImages() error { func (a *Client) DryRun() error { // Check if the desired instance type is supported in the region - fmt.Printf("Checking if instance type %s is supported in region %s\n", string(a.Spec.Instance.Type), a.Spec.Instance.Region) + a.log.Wg.Add(1) + go a.log.Loading("Checking if instance type %s is supported in region %s\n", string(a.Spec.Instance.Type), a.Spec.Instance.Region) err := a.checkInstanceTypes() if err != nil { + a.fail() return err } + a.done() // Check if the desired image is supported in the region - fmt.Printf("Checking if image %s is supported in region %s\n", *a.Spec.Instance.Image.ImageId, a.Spec.Instance.Region) + a.log.Wg.Add(1) + go a.log.Loading("Checking if image %s is supported in region %s\n", *a.Spec.Instance.Image.ImageId, a.Spec.Instance.Region) err = a.checkImages() if err != nil { + a.fail() return fmt.Errorf("failed to get images: %v", err) } + a.done() return nil } + +func (a *Client) done() { + a.log.Done <- struct{}{} + a.log.Wg.Wait() +} + +func (a *Client) fail() { + a.log.Fail <- struct{}{} + a.log.Wg.Wait() +} diff --git a/pkg/provisioner/dependency.go b/pkg/provisioner/dependency.go new file mode 100644 index 00000000..e3ece727 --- /dev/null +++ b/pkg/provisioner/dependency.go @@ -0,0 +1,192 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package provisioner + +import ( + "bytes" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/pkg/provisioner/templates" +) + +type ProvisionFunc func(tpl *bytes.Buffer, env v1alpha1.Environment) error + +type DependencyGraph map[string][]string + +var ( + functions = map[string]ProvisionFunc{ + "kubeadm": kubeadm, + "kind": kind, + "microk8s": microk8s, + "containerd": containerd, + "crio": criO, + "docker": docker, + "nvdriver": nvdriver, + "containerToolkit": containerToolkit, + } +) + +// buildDependencyGraph builds a dependency graph based on the environment +// and returns a topologically sorted list of provisioning functions +// to be executed in an opinionated order +func buildDependencyGraph(env v1alpha1.Environment) ([]ProvisionFunc, error) { + // Predefined dependency graph + graph := DependencyGraph{ + "kubeadm": {}, + "containerToolkit": {"containerToolkit", "nvdriver"}, + } + + // for kubeadm + if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeContainerd { + graph["kubeadm"] = append(graph["kubeadm"], "containerd") + } else if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeCrio { + graph["kubeadm"] = append(graph["kubeadm"], "crio") + } else if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeDocker { + graph["kubeadm"] = append(graph["kubeadm"], "docker") + } else if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeNone { + // default to containerd if ContainerRuntime is empty + graph["kubeadm"] = append(graph["kubeadm"], "containerd") + } + + // if container toolkit is enabled then add container toolkit and nvdriver to kubeadm + if env.Spec.NVContainerToolKit.Install { + graph["kubeadm"] = append(graph["kubeadm"], "containerToolkit") + graph["kubeadm"] = append(graph["kubeadm"], "nvdriver") + } + + // for container toolkit + if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeContainerd { + graph["containerToolkit"] = append(graph["containerToolkit"], "containerd") + } else if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeCrio { + graph["containerToolkit"] = append(graph["containerToolkit"], "crio") + } else if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeDocker { + graph["containerToolkit"] = append(graph["containerToolkit"], "docker") + } else if env.Spec.ContainerRuntime.Name == v1alpha1.ContainerRuntimeNone { + // default to containerd if ContainerRuntime is empty + graph["kubeadm"] = append(graph["kubeadm"], "containerd") + } + + // user might request to install container toolkit without nvdriver for testing purpose + if env.Spec.NVDriver.Install { + graph["containerToolkit"] = append(graph["containerToolkit"], "nvdriver") + } + + ordered := []ProvisionFunc{} + // We go from up to bottom in the graph + // Kubernetes -> Container Toolkit -> Container Runtime -> NVDriver + // if a dependency is needed and not defined, we set an opinionated default + if env.Spec.Kubernetes.Install { + switch env.Spec.Kubernetes.KubernetesInstaller { + case "kubeadm": + for _, f := range graph["kubeadm"] { + ordered = append(ordered, functions[f]) + } + return ordered, nil + case "kind": + return []ProvisionFunc{docker, containerToolkit, nvdriver, kind}, nil + case "microk8s": + return []ProvisionFunc{microk8s}, nil + default: + // default to kubeadm if KubernetesInstaller is empty + for _, f := range graph["kubeadm"] { + ordered = append(ordered, functions[f]) + } + return ordered, nil + } + } + + // If no kubernetes is requested, we move to container-toolkit + if env.Spec.NVContainerToolKit.Install { + for _, f := range graph["containerToolkit"] { + ordered = append(ordered, functions[f]) + } + return ordered, nil + } + + // If no container-toolkit is requested, we move to container-runtime + if env.Spec.ContainerRuntime.Install { + switch env.Spec.ContainerRuntime.Name { + case "containerd": + ordered = append(ordered, functions["containerd"]) + return ordered, nil + case "crio": + return ordered, nil + case "docker": + return ordered, nil + default: + // default to containerd if ContainerRuntime.Name is empty + return ordered, nil + } + } + + // If no container-runtime is requested, we move to nvdriver + if env.Spec.NVDriver.Install { + ordered = append(ordered, functions["nvdriver"]) + return ordered, nil + } + + return nil, nil +} + +func nvdriver(tpl *bytes.Buffer, env v1alpha1.Environment) error { + nvdriver := templates.NewNvDriver() + return nvdriver.Execute(tpl, env) +} + +func docker(tpl *bytes.Buffer, env v1alpha1.Environment) error { + docker := templates.NewDocker(env) + return docker.Execute(tpl, env) +} + +func containerd(tpl *bytes.Buffer, env v1alpha1.Environment) error { + containerd := templates.NewContainerd(env) + return containerd.Execute(tpl, env) +} + +func criO(tpl *bytes.Buffer, env v1alpha1.Environment) error { + criO := templates.NewCriO(env) + return criO.Execute(tpl, env) +} + +func containerToolkit(tpl *bytes.Buffer, env v1alpha1.Environment) error { + containerToolkit := templates.NewContainerToolkit(env) + return containerToolkit.Execute(tpl, env) +} + +func kubeadm(tpl *bytes.Buffer, env v1alpha1.Environment) error { + kubernetes, err := templates.NewKubernetes(env) + if err != nil { + return err + } + return kubernetes.Execute(tpl, env) +} + +func microk8s(tpl *bytes.Buffer, env v1alpha1.Environment) error { + microk8s, err := templates.NewKubernetes(env) + if err != nil { + return err + } + return microk8s.Execute(tpl, env) +} + +func kind(tpl *bytes.Buffer, env v1alpha1.Environment) error { + kind, err := templates.NewKubernetes(env) + if err != nil { + return err + } + return kind.Execute(tpl, env) +} diff --git a/pkg/provisioner/dryrun.go b/pkg/provisioner/dryrun.go index 3269b36a..7a7a5267 100644 --- a/pkg/provisioner/dryrun.go +++ b/pkg/provisioner/dryrun.go @@ -21,34 +21,36 @@ import ( "strings" "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/internal/logger" ) -func Dryrun(env v1alpha1.Environment) error { +func Dryrun(log *logger.FunLogger, env v1alpha1.Environment) error { // Resolve dependencies from top to bottom - fmt.Printf("Resolving dependencies...\n") - // kubernetes -> container runtime -> node + log.Wg.Add(1) + + go log.Loading("Resolving dependencies \U0001F4E6 ...") + // Kubernetes -> Container Toolkit -> Container Runtime -> NVDriver if env.Spec.Kubernetes.Install { if !env.Spec.ContainerRuntime.Install { + log.Fail <- struct{}{} return fmt.Errorf("cannot install Kubernetes without a container runtime") } // check if env.Spec.Kubernetes.KubernetesVersion is in the format of vX.Y.Z - if !strings.HasPrefix(env.Spec.Kubernetes.KubernetesVersion, "v") { - return fmt.Errorf("Kubernetes version %s is not in the format of vX.Y.Z", env.Spec.Kubernetes.KubernetesVersion) + if env.Spec.Kubernetes.KubernetesInstaller == "kubeadm" && !strings.HasPrefix(env.Spec.Kubernetes.KubernetesVersion, "v") { + log.Fail <- struct{}{} + return fmt.Errorf("kubernetes version %s is not in the format of vX.Y.Z", env.Spec.Kubernetes.KubernetesVersion) } } if env.Spec.ContainerRuntime.Install && (env.Spec.ContainerRuntime.Name != v1alpha1.ContainerRuntimeContainerd && env.Spec.ContainerRuntime.Name != v1alpha1.ContainerRuntimeCrio && env.Spec.ContainerRuntime.Name != v1alpha1.ContainerRuntimeDocker) { + log.Fail <- struct{}{} return fmt.Errorf("container runtime %s not supported", env.Spec.ContainerRuntime.Name) } - if env.Spec.NVContainerToolKit.Install && !env.Spec.ContainerRuntime.Install { - return fmt.Errorf("cannot install NVContainer Toolkit without a container runtime") - } - if env.Spec.NVContainerToolKit.Install && !env.Spec.NVDriver.Install { - return fmt.Errorf("cannot install NVContainer Toolkit without the NVIDIA driver") - } + log.Done <- struct{}{} + log.Wg.Wait() return nil } diff --git a/pkg/provisioner/provisioner.go b/pkg/provisioner/provisioner.go index 8e82caf3..ae1207ab 100644 --- a/pkg/provisioner/provisioner.go +++ b/pkg/provisioner/provisioner.go @@ -32,6 +32,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" + "github.com/NVIDIA/holodeck/internal/logger" "github.com/NVIDIA/holodeck/pkg/provisioner/templates" ) @@ -44,32 +45,13 @@ type Provisioner struct { SessionManager *ssm.Client HostUrl string + KeyPath string tpl bytes.Buffer -} - -type Containerd struct { - Version string -} -type Crio struct { - Version string + log *logger.FunLogger } -type Docker struct { - Version string -} - -type ContainerToolkit struct { - ContainerRuntime string -} - -type NvDriver struct { - // Empty struct - // Placeholder to enable type assertion -} - -func New(keyPath, hostUrl string) (*Provisioner, error) { - fmt.Printf("Connecting to %s\n", hostUrl) +func New(log *logger.FunLogger, keyPath, hostUrl string) (*Provisioner, error) { client, err := connectOrDie(keyPath, hostUrl) if err != nil { return nil, fmt.Errorf("failed to connect to %s: %v", hostUrl, err) @@ -78,208 +60,72 @@ func New(keyPath, hostUrl string) (*Provisioner, error) { p := &Provisioner{ Client: client, HostUrl: hostUrl, + KeyPath: keyPath, tpl: bytes.Buffer{}, - } - - // Add script header and common functions to the script - if err := addScriptHeader(&p.tpl); err != nil { - return nil, fmt.Errorf("failed to add shebang to the script: %v", err) + log: log, } return p, nil } -func addScriptHeader(tpl *bytes.Buffer) error { - // Add shebang to the script - shebang := template.Must(template.New("shebang").Parse(Shebang)) - if err := shebang.Execute(tpl, nil); err != nil { - return fmt.Errorf("failed to add shebang to the script: %v", err) - } - // Add common functions to the script - commonFunctions := template.Must(template.New("common-functions").Parse(templates.CommonFunctions)) - if err := commonFunctions.Execute(tpl, nil); err != nil { - return fmt.Errorf("failed to add common functions to the script: %v", err) - } - return nil -} - -// resetConnection resets the ssh connection, and retries if it fails to connect -func (p *Provisioner) resetConnection(keyPath, hostUrl string) error { - var err error - - // Close the current ssh connection - if err := p.Client.Close(); err != nil { - return fmt.Errorf("failed to close ssh client: %v", err) - } - - // Create a new ssh connection - p.Client, err = connectOrDie(keyPath, hostUrl) - if err != nil { - return fmt.Errorf("failed to connect to %s: %v", p.HostUrl, err) - } - - return nil -} - -// createSshClient creates a ssh client, and retries if it fails to connect -func connectOrDie(keyPath, hostUrl string) (*ssh.Client, error) { - var client *ssh.Client - var err error - key, err := os.ReadFile(keyPath) - if err != nil { - return nil, fmt.Errorf("failed to read key file: %v", err) - } - signer, err := ssh.ParsePrivateKey(key) +func (p *Provisioner) Run(env v1alpha1.Environment) error { + graph, err := buildDependencyGraph(env) if err != nil { - return nil, fmt.Errorf("failed to parse private key: %v", err) - } - sshConfig := &ssh.ClientConfig{ - User: "ubuntu", - Auth: []ssh.AuthMethod{ - ssh.PublicKeys(signer), - }, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + return fmt.Errorf("failed to build dependency graph: %v", err) } - connectionFailed := false - for i := 0; i < 20; i++ { - client, err = ssh.Dial("tcp", hostUrl+":22", sshConfig) - if err == nil { - return client, nil // Connection succeeded, return the client. + // kind-config + // Create kind config file if it is provided + if env.Spec.Kubernetes.KubernetesInstaller == "kind" && env.Spec.Kubernetes.KindConfig != "" { + if err := p.createKindConfig(env); err != nil { + return fmt.Errorf("failed to create kind config file: %v", err) } - fmt.Printf("Failed to connect to %s: %v\n", hostUrl, err) - connectionFailed = true - // Sleep for a brief moment before retrying. - // You can adjust the duration based on your requirements. - time.Sleep(1 * time.Second) } - if connectionFailed { - fmt.Printf("Failed to connect to %s after 10 retries, giving up\n", hostUrl) - return nil, err - } - - return client, nil -} - -func (p *Provisioner) Run(env v1alpha1.Environment) error { - // Read v1alpha1.Provisioner and execute the template based on the config - // the logic order is: - // 1. container-runtime - // 2. container-toolkit - // 3. nv-driver - // 4. kubernetes - // 5. kind-config (if kubernetes is installed via kind) - var containerRuntime string - - // 1. container-runtime - if env.Spec.ContainerRuntime.Install { - switch env.Spec.ContainerRuntime.Name { - case "docker": - containerRuntime = "docker" - dockerTemplate := template.Must(template.New("docker").Parse(templates.Docker)) - if env.Spec.ContainerRuntime.Version == "" { - env.Spec.ContainerRuntime.Version = "latest" - } - err := dockerTemplate.Execute(&p.tpl, &Docker{Version: env.Spec.ContainerRuntime.Version}) - if err != nil { - return fmt.Errorf("failed to execute docker template: %v", err) - } - case "crio": - containerRuntime = "crio" - crioTemplate := template.Must(template.New("crio").Parse(templates.Crio)) - err := crioTemplate.Execute(&p.tpl, &Crio{Version: env.Spec.ContainerRuntime.Version}) - if err != nil { - return fmt.Errorf("failed to execute crio template: %v", err) - } - case "containerd": - containerRuntime = "containerd" - containerdTemplate := template.Must(template.New("containerd").Parse(templates.Containerd)) - err := containerdTemplate.Execute(&p.tpl, &Containerd{Version: env.Spec.ContainerRuntime.Version}) - if err != nil { - return fmt.Errorf("failed to execute containerd template: %v", err) - } - default: - fmt.Printf("Unknown container runtime %s\n", env.Spec.ContainerRuntime.Name) - return nil + for _, node := range graph { + // Add script header and common functions to the script + if err := addScriptHeader(&p.tpl); err != nil { + return fmt.Errorf("failed to add shebang to the script: %v", err) } - } else if env.Spec.Kubernetes.KubernetesInstaller == "kind" { - // If kubernetes is installed via kind, we need to install docker - // as the container runtime - containerRuntime = "docker" - dockerTemplate := template.Must(template.New("docker").Parse(templates.Docker)) - err := dockerTemplate.Execute(&p.tpl, &Docker{Version: "latest"}) - if err != nil { - return fmt.Errorf("failed to execute docker template: %v", err) + // Execute the template for the dependency + if err := node(&p.tpl, env); err != nil { + return fmt.Errorf("failed to execute template: %w", err) } - - // And since we want to use KIND non-root mode, we need to add the user - // to the docker group so that the user can run docker commands without - // sudo + // Provision the instance if err := p.provision(); err != nil { return fmt.Errorf("failed to provision: %v", err) } - p.tpl.Reset() - if err := addScriptHeader(&p.tpl); err != nil { - return fmt.Errorf("failed to add shebang to the script: %v", err) - } - // close session to force docker group to take effect + // Reset the connection, this step is needed to make sure some configuration changes take effect + // e.g after installing docker, the user needs to be added to the docker group if err := p.resetConnection(env.Spec.Auth.PrivateKey, p.HostUrl); err != nil { - return fmt.Errorf("failed to reset ssh connection: %v", err) - } - } - - // 2. container-toolkit - // We need to install container-toolkit after container-runtime or skip it - // We also need to install container-toolkit if kubernetes is installed - // via kind - if env.Spec.NVContainerToolKit.Install && env.Spec.ContainerRuntime.Install || env.Spec.Kubernetes.KubernetesInstaller == "kind" { - containerToolkitTemplate := template.Must(template.New("container-toolkit").Parse(templates.ContainerToolkit)) - err := containerToolkitTemplate.Execute(&p.tpl, &ContainerToolkit{ContainerRuntime: containerRuntime}) - if err != nil { - return fmt.Errorf("failed to execute container-toolkit template: %v", err) - } - } - - // 3. nv-driver - // We need to install nv-driver if container-runtime if Kind is used - if env.Spec.NVDriver.Install || env.Spec.Kubernetes.KubernetesInstaller == "kind" { - nvDriverTemplate := template.Must(template.New("nv-driver").Parse(templates.NvDriver)) - err := nvDriverTemplate.Execute(&p.tpl, &NvDriver{}) - if err != nil { - return fmt.Errorf("failed to execute nv-driver template: %v", err) - } - } - - // 4. kubernetes - // Set opinionated defaults if not set - if env.Spec.Kubernetes.Install { - if env.Spec.Kubernetes.K8sEndpointHost == "" { - env.Spec.Kubernetes.K8sEndpointHost = p.HostUrl - } - err := templates.ExecuteKubernetes(&p.tpl, env) - if err != nil { - return fmt.Errorf("failed to execute kubernetes template: %v", err) + return fmt.Errorf("failed to reset connection: %v", err) } + // Clear the template buffer + p.tpl.Reset() } - // 5. kind-config - // Create kind config file if it is set - if env.Spec.Kubernetes.KubernetesInstaller == "kind" && env.Spec.Kubernetes.KindConfig != "" { - if err := p.createKindConfig(env); err != nil { - return fmt.Errorf("failed to create kind config file: %v", err) - } - } + return nil +} - // Provision the instance - if err := p.provision(); err != nil { - return fmt.Errorf("failed to provision: %v", err) +// resetConnection resets the ssh connection, and retries if it fails to connect +func (p *Provisioner) resetConnection(keyPath, hostUrl string) error { + // Close the current ssh connection + if err := p.Client.Close(); err != nil { + return fmt.Errorf("failed to close ssh client: %v", err) } return nil } func (p *Provisioner) provision() error { + var err error + + // Create a new ssh connection + p.Client, err = connectOrDie(p.KeyPath, p.HostUrl) + if err != nil { + return fmt.Errorf("failed to connect to %s: %v", p.HostUrl, err) + } + // Create a session session, err := p.Client.NewSession() if err != nil { @@ -299,6 +145,7 @@ func (p *Provisioner) provision() error { defer session.Close() script := p.tpl.String() + // run the script err = session.Start(script) if err != nil { @@ -358,3 +205,56 @@ func (p *Provisioner) createKindConfig(env v1alpha1.Environment) error { session.Wait() return nil } + +func addScriptHeader(tpl *bytes.Buffer) error { + // Add shebang to the script + shebang := template.Must(template.New("shebang").Parse(Shebang)) + if err := shebang.Execute(tpl, nil); err != nil { + return fmt.Errorf("failed to add shebang to the script: %v", err) + } + // Add common functions to the script + commonFunctions := template.Must(template.New("common-functions").Parse(templates.CommonFunctions)) + if err := commonFunctions.Execute(tpl, nil); err != nil { + return fmt.Errorf("failed to add common functions to the script: %v", err) + } + return nil +} + +// createSshClient creates a ssh client, and retries if it fails to connect +func connectOrDie(keyPath, hostUrl string) (*ssh.Client, error) { + var client *ssh.Client + var err error + key, err := os.ReadFile(keyPath) + if err != nil { + return nil, fmt.Errorf("failed to read key file: %v", err) + } + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + sshConfig := &ssh.ClientConfig{ + User: "ubuntu", + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + connectionFailed := false + for i := 0; i < 20; i++ { + client, err = ssh.Dial("tcp", hostUrl+":22", sshConfig) + if err == nil { + return client, nil // Connection succeeded, return the client. + } + connectionFailed = true + // Sleep for a brief moment before retrying. + // You can adjust the duration based on your requirements. + time.Sleep(1 * time.Second) + } + + if connectionFailed { + return nil, fmt.Errorf("failed to connect to %s after 10 retries, giving up", hostUrl) + } + + return client, nil +} diff --git a/pkg/provisioner/templates/common.go b/pkg/provisioner/templates/common.go index e65f995b..9493d4b1 100644 --- a/pkg/provisioner/templates/common.go +++ b/pkg/provisioner/templates/common.go @@ -16,6 +16,12 @@ package templates +import ( + "bytes" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + const CommonFunctions = ` export DEBIAN_FRONTEND=noninteractive @@ -74,3 +80,8 @@ with_retry() { return 1 } ` + +// Template is the interface that wraps the Execute method. +type Template interface { + Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error +} diff --git a/pkg/provisioner/templates/container-toolkit.go b/pkg/provisioner/templates/container-toolkit.go index 4a1ffe16..ac992796 100644 --- a/pkg/provisioner/templates/container-toolkit.go +++ b/pkg/provisioner/templates/container-toolkit.go @@ -16,7 +16,17 @@ package templates -const ContainerToolkit = ` +import ( + "bytes" + "fmt" + "text/template" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + +const containerToolkitTemplate = ` + +# Install container toolkit curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ && curl -s -L https://nvidia.github.io/libnvidia-container/stable/deb/nvidia-container-toolkit.list | \ sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ @@ -30,3 +40,26 @@ sudo apt-get install -y nvidia-container-toolkit sudo nvidia-ctk runtime configure --runtime={{.ContainerRuntime}} --set-as-default sudo systemctl restart {{.ContainerRuntime}} ` + +type ContainerToolkit struct { + ContainerRuntime string +} + +func NewContainerToolkit(env v1alpha1.Environment) *ContainerToolkit { + return &ContainerToolkit{ + ContainerRuntime: string(env.Spec.ContainerRuntime.Name), + } +} + +func (t *ContainerToolkit) Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error { + containerTlktTemplate := template.Must(template.New("container-toolkit").Parse(containerToolkitTemplate)) + if err := containerTlktTemplate.Execute(tpl, t); err != nil { + return fmt.Errorf("failed to execute container-toolkit template: %v", err) + } + + if err := containerTlktTemplate.Execute(tpl, t); err != nil { + return fmt.Errorf("failed to execute container-toolkit template: %v", err) + } + + return nil +} diff --git a/pkg/provisioner/templates/containerd.go b/pkg/provisioner/templates/containerd.go index 1e246a97..d78141b9 100644 --- a/pkg/provisioner/templates/containerd.go +++ b/pkg/provisioner/templates/containerd.go @@ -16,7 +16,16 @@ package templates -const Containerd = ` +import ( + "bytes" + "fmt" + "strings" + "text/template" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + +const containerdTemplate = ` : ${CONTAINERD_VERSION:={{.Version}}} # Install required packages @@ -47,3 +56,33 @@ sudo sed -i 's/SystemdCgroup \= false/SystemdCgroup \= true/g' /etc/containerd/c sudo systemctl restart containerd sudo systemctl enable containerd ` + +type Containerd struct { + Version string +} + +func NewContainerd(env v1alpha1.Environment) *Containerd { + var version string + + if env.Spec.ContainerRuntime.Version != "" { + if strings.HasPrefix(env.Spec.ContainerRuntime.Version, "v") { + version = strings.TrimPrefix(env.Spec.ContainerRuntime.Version, "v") + } else { + version = env.Spec.ContainerRuntime.Version + } + } else { + version = "1.6.27" + } + return &Containerd{ + Version: version, + } +} + +func (t *Containerd) Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error { + containerdTemplate := template.Must(template.New("containerd").Parse(containerdTemplate)) + err := containerdTemplate.Execute(tpl, &Containerd{Version: env.Spec.ContainerRuntime.Version}) + if err != nil { + return fmt.Errorf("failed to execute containerd template: %v", err) + } + return nil +} diff --git a/pkg/provisioner/templates/crio.go b/pkg/provisioner/templates/crio.go index f3f0ff13..3539eb6d 100644 --- a/pkg/provisioner/templates/crio.go +++ b/pkg/provisioner/templates/crio.go @@ -16,7 +16,15 @@ package templates -const Crio = ` +import ( + "bytes" + "fmt" + "text/template" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + +const criOTemplate = ` : ${CRIO_VERSION:={{.Version}} # Add Cri-o repo @@ -34,3 +42,22 @@ systemctl daemon-reload systemctl restart crio systemctl enable crio ` + +type CriO struct { + Version string +} + +func NewCriO(env v1alpha1.Environment) *CriO { + return &CriO{ + Version: env.Spec.ContainerRuntime.Version, + } +} + +func (t *CriO) Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error { + criOTemplate := template.Must(template.New("crio").Parse(criOTemplate)) + if err := criOTemplate.Execute(tpl, t); err != nil { + return fmt.Errorf("failed to execute crio template: %v", err) + } + + return nil +} diff --git a/pkg/provisioner/templates/docker.go b/pkg/provisioner/templates/docker.go index 4eedabb6..0cc03326 100644 --- a/pkg/provisioner/templates/docker.go +++ b/pkg/provisioner/templates/docker.go @@ -16,7 +16,15 @@ package templates -const Docker = ` +import ( + "bytes" + "fmt" + "text/template" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + +const dockerTemplate = ` # Based on https://docs.docker.com/engine/install/ubuntu/#install-using-the-repository : ${DOCKER_VERSION:={{.Version}}} @@ -66,3 +74,29 @@ sudo systemctl restart docker sudo usermod -aG docker $USER newgrp docker ` + +type Docker struct { + Version string +} + +func NewDocker(env v1alpha1.Environment) *Docker { + var version string + + if env.Spec.ContainerRuntime.Version != "" { + version = env.Spec.ContainerRuntime.Version + } else { + version = "latest" + } + return &Docker{ + Version: version, + } +} + +func (t *Docker) Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error { + dockerTemplate := template.Must(template.New("docker").Parse(dockerTemplate)) + if err := dockerTemplate.Execute(tpl, t); err != nil { + return fmt.Errorf("failed to execute docker template: %v", err) + } + + return nil +} diff --git a/pkg/provisioner/templates/kubernetes.go b/pkg/provisioner/templates/kubernetes.go index a01e2ab0..3b31443b 100644 --- a/pkg/provisioner/templates/kubernetes.go +++ b/pkg/provisioner/templates/kubernetes.go @@ -25,21 +25,7 @@ import ( "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" ) -type Kubernetes struct { - Version string - Installer string - KubeletReleaseVersion string - Arch string - CniPluginsVersion string - CalicoVersion string - CrictlVersion string - K8sEndpointHost string - KubeAdmnFeatureGates string - // Kind exclusive - KindConfig string -} - -const KubernetesTemplate = ` +const KubeadmTemplate = ` # Install kubeadm, kubectl, and k8s-cni : ${K8S_VERSION:={{.Version}}} @@ -148,24 +134,46 @@ sudo chown -R $(id -u):$(id -g) $HOME/.kube/ with_retry 3 10s kind create cluster --name holodeck --config kind.yaml --kubeconfig="${HOME}/.kube/config" ` -func ExecuteKubernetes(tpl *bytes.Buffer, env v1alpha1.Environment) error { - kubernetesTemplate := new(template.Template) +const microk8sTemplate = ` - switch env.Spec.Kubernetes.KubernetesInstaller { - case "kubeadm": - kubernetesTemplate = template.Must(template.New("kubeadm").Parse(KubernetesTemplate)) - case "kind": - kubernetesTemplate = template.Must(template.New("kind").Parse(KindTemplate)) - default: - return fmt.Errorf("unknown kubernetes installer %s", env.Spec.Kubernetes.KubernetesInstaller) - } +# Install microk8s +sudo apt-get update + +sudo snap install microk8s --classic --channel={{.Version}} +sudo microk8s enable gpu dashboard dns registry +sudo usermod -a -G microk8s ubuntu +mkdir -p ~/.kube +sudo chown -f -R ubuntu ~/.kube +sudo microk8s config > ~/.kube/config +sudo chown -f -R ubuntu ~/.kube +sudo snap alias microk8s.kubectl kubectl + +echo "Microk8s {{.Version}} installed successfully" +echo "you can now access the cluster with:" +echo "ssh -i ubuntu@{{.K8sEndpointHost}}" +` + +type Kubernetes struct { + Version string + Installer string + KubeletReleaseVersion string + Arch string + CniPluginsVersion string + CalicoVersion string + CrictlVersion string + K8sEndpointHost string + KubeAdmnFeatureGates string + // Kind exclusive + KindConfig string +} +func NewKubernetes(env v1alpha1.Environment) (*Kubernetes, error) { kubernetes := &Kubernetes{ Version: env.Spec.Kubernetes.KubernetesVersion, } // check if env.Spec.Kubernetes.KubernetesVersion is in the format of vX.Y.Z // if not, set the default version - if !strings.HasPrefix(env.Spec.Kubernetes.KubernetesVersion, "v") { + if !strings.HasPrefix(env.Spec.Kubernetes.KubernetesVersion, "v") && env.Spec.Kubernetes.KubernetesInstaller != "microk8s" { fmt.Printf("Kubernetes version %s is not in the format of vX.Y.Z, setting default version v1.27.9\n", env.Spec.Kubernetes.KubernetesVersion) kubernetes.Version = "v1.27.9" } @@ -201,9 +209,27 @@ func ExecuteKubernetes(tpl *bytes.Buffer, env v1alpha1.Environment) error { kubernetes.KindConfig = env.Spec.Kubernetes.KindConfig } - err := kubernetesTemplate.Execute(tpl, kubernetes) + return kubernetes, nil +} + +func (k *Kubernetes) Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error { + kubernetesTemplate := new(template.Template) + + switch env.Spec.Kubernetes.KubernetesInstaller { + case "kubeadm": + kubernetesTemplate = template.Must(template.New("kubeadm").Parse(KubeadmTemplate)) + case "kind": + kubernetesTemplate = template.Must(template.New("kind").Parse(KindTemplate)) + case "microk8s": + kubernetesTemplate = template.Must(template.New("microk8s").Parse(microk8sTemplate)) + default: + return fmt.Errorf("unknown kubernetes installer %s", env.Spec.Kubernetes.KubernetesInstaller) + } + + err := kubernetesTemplate.Execute(tpl, k) if err != nil { return fmt.Errorf("failed to execute kubernetes template: %v", err) } + return nil } diff --git a/pkg/provisioner/templates/nv-driver.go b/pkg/provisioner/templates/nv-driver.go index 2d1b105e..b650a09b 100644 --- a/pkg/provisioner/templates/nv-driver.go +++ b/pkg/provisioner/templates/nv-driver.go @@ -16,8 +16,16 @@ package templates +import ( + "bytes" + "fmt" + "text/template" + + "github.com/NVIDIA/holodeck/api/holodeck/v1alpha1" +) + // From https://docs.nvidia.com/datacenter/tesla/tesla-installation-notes/index.html#ubuntu-lts -const NvDriver = ` +const NvDriverTemplate = ` sudo apt-get update install_packages_with_retry linux-headers-$(uname -r) @@ -28,3 +36,19 @@ sudo dpkg -i cuda-keyring_1.0-1_all.deb with_retry 3 10s sudo apt-get update install_packages_with_retry cuda-drivers ` + +type NvDriver struct { +} + +func NewNvDriver() *NvDriver { + return &NvDriver{} +} + +func (t *NvDriver) Execute(tpl *bytes.Buffer, env v1alpha1.Environment) error { + nvDriverTemplate := template.Must(template.New("nv-driver").Parse(NvDriverTemplate)) + err := nvDriverTemplate.Execute(tpl, &NvDriver{}) + if err != nil { + return fmt.Errorf("failed to execute nv-driver template: %v", err) + } + return nil +}