diff --git a/pkg/template/config/config.go b/pkg/template/config/config.go index b6cf5b3..680814d 100644 --- a/pkg/template/config/config.go +++ b/pkg/template/config/config.go @@ -1,6 +1,8 @@ package config import ( + "fmt" + "github.com/footprintai/multikf/pkg/machine" "github.com/footprintai/multikf/pkg/template" ) @@ -81,10 +83,13 @@ func (t *DefaultTemplateConfig) AuditFileAbsolutePath() string { return t.auditFileAbsolutePath } -func (t *DefaultTemplateConfig) GetWorkerIDs() []int { - ids := make([]int, t.workerCount, t.workerCount) +func (t *DefaultTemplateConfig) GetWorkers() []template.Worker { + ids := make([]template.Worker, t.workerCount, t.workerCount) for i := 0; i < t.workerCount; i++ { - ids[i] = i + ids[i] = template.Worker{ + Id: fmt.Sprintf("%d", i), + UseGPU: t.GetGPUs() > 0, + } } return ids } diff --git a/pkg/template/kind_template.go b/pkg/template/kind_template.go index 856a2fe..709e9e9 100644 --- a/pkg/template/kind_template.go +++ b/pkg/template/kind_template.go @@ -36,7 +36,7 @@ type KindConfiger interface { GpuGetter ExportPortsGetter AuditEnabler - WorkerIDsGetter + WorkersGetter NodeLabelsGetter } @@ -52,7 +52,7 @@ func (k *KindFileTemplate) Populate(v interface{}) error { k.ExportPorts = c.GetExportPorts() k.AuditEnabled = c.AuditEnabled() k.AuditFileAbsolutePath = c.AuditFileAbsolutePath() - k.WorkerIDs = c.GetWorkerIDs() + k.Workers = c.GetWorkers() nodeLabels := c.GetNodeLabels() k.NodeLabels = make([]string, len(nodeLabels), len(nodeLabels)) @@ -72,7 +72,7 @@ type KindFileTemplate struct { ExportPorts []machine.ExportPortPair AuditEnabled bool AuditFileAbsolutePath string - WorkerIDs []int + Workers []Worker NodeLabels []string } @@ -130,9 +130,10 @@ nodes: containerPath: /etc/kubernetes/policies/audit-policy.yaml readOnly: true {{- end}} -{{- range .WorkerIDs }} +{{- range .Workers }} - role: worker image: kindest/node:v1.23.12@sha256:9402cf1330bbd3a0d097d2033fa489b2abe40d479cc5ef47d0b6a6960613148a + gpus: {{ .UseGPU}} {{- end}} networking: apiServerAddress: {{.KubeAPIIP}} diff --git a/pkg/template/kind_template_test.go b/pkg/template/kind_template_test.go index ea03c78..d58496c 100644 --- a/pkg/template/kind_template_test.go +++ b/pkg/template/kind_template_test.go @@ -8,6 +8,10 @@ import ( "github.com/stretchr/testify/assert" ) +var ( + _ KindConfiger = staticConfig{} +) + type staticConfig struct{} func (s staticConfig) GetName() string { @@ -34,8 +38,21 @@ func (s staticConfig) AuditFileAbsolutePath() string { return "" } -func (s staticConfig) GetWorkerIDs() []int { - return []int{1, 2, 3} +func (s staticConfig) GetWorkers() []Worker { + return []Worker{ + Worker{ + Id: "1", + UseGPU: true, + }, + Worker{ + Id: "2", + UseGPU: true, + }, + Worker{ + Id: "3", + UseGPU: true, + }, + } } func (s staticConfig) GetNodeLabels() []machine.NodeLabel { @@ -97,14 +114,20 @@ nodes: protocol: TCP - role: worker image: kindest/node:v1.23.12@sha256:9402cf1330bbd3a0d097d2033fa489b2abe40d479cc5ef47d0b6a6960613148a + gpus: true - role: worker image: kindest/node:v1.23.12@sha256:9402cf1330bbd3a0d097d2033fa489b2abe40d479cc5ef47d0b6a6960613148a + gpus: true - role: worker image: kindest/node:v1.23.12@sha256:9402cf1330bbd3a0d097d2033fa489b2abe40d479cc5ef47d0b6a6960613148a + gpus: true networking: apiServerAddress: 1.2.3.4 apiServerPort: 8443 ` +var ( + _ KindConfiger = auditConfig{} +) type auditConfig struct{} @@ -141,8 +164,8 @@ func (s auditConfig) AuditFileAbsolutePath() string { return "foo.bar.yaml" } -func (s auditConfig) GetWorkerIDs() []int { - return []int{} +func (s auditConfig) GetWorkers() []Worker { + return []Worker{} } func (s auditConfig) GetNodeLabels() []machine.NodeLabel { diff --git a/pkg/template/template.go b/pkg/template/template.go index 626b7f9..f40570c 100644 --- a/pkg/template/template.go +++ b/pkg/template/template.go @@ -50,8 +50,13 @@ type AuditEnabler interface { AuditFileAbsolutePath() string } -type WorkerIDsGetter interface { - GetWorkerIDs() []int +type WorkersGetter interface { + GetWorkers() []Worker +} + +type Worker struct { + Id string + UseGPU bool } type NodeLabelsGetter interface {