From d19f0252658696f76847904965369ee222662e61 Mon Sep 17 00:00:00 2001 From: Keming Date: Fri, 7 Oct 2022 13:48:01 +0800 Subject: [PATCH] fix(ir): default value Signed-off-by: Keming --- pkg/lang/frontend/starlark/install/install.go | 11 ++++------- pkg/lang/frontend/starlark/universe/universe.go | 3 +-- pkg/lang/ir/compile.go | 2 +- pkg/lang/ir/consts.go | 1 + pkg/lang/ir/interface.go | 8 ++++++-- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/pkg/lang/frontend/starlark/install/install.go b/pkg/lang/frontend/starlark/install/install.go index 88ad3dfca..e9de9ef35 100644 --- a/pkg/lang/frontend/starlark/install/install.go +++ b/pkg/lang/frontend/starlark/install/install.go @@ -132,19 +132,16 @@ func ruleFuncSystemPackage(thread *starlark.Thread, _ *starlark.Builtin, func ruleFuncCUDA(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { - var version, cudnn starlark.String + var version, cudnn string if err := starlark.UnpackArgs(ruleCUDA, args, kwargs, - "version?", &version, "cudnn?", &cudnn); err != nil { + "version", &version, "cudnn?", &cudnn); err != nil { return nil, err } - versionStr := version.GoString() - cudnnStr := cudnn.GoString() - logger.Debugf("rule `%s` is invoked, version=%s, cudnn=%s", - ruleCUDA, versionStr, cudnnStr) - ir.CUDA(versionStr, cudnnStr) + ruleCUDA, version, cudnn) + ir.CUDA(version, cudnn) return starlark.None, nil } diff --git a/pkg/lang/frontend/starlark/universe/universe.go b/pkg/lang/frontend/starlark/universe/universe.go index fd0c375d7..1f46b97f8 100644 --- a/pkg/lang/frontend/starlark/universe/universe.go +++ b/pkg/lang/frontend/starlark/universe/universe.go @@ -46,10 +46,9 @@ func RegisterBuildContext(buildContextDir string) { func ruleFuncBase(thread *starlark.Thread, _ *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { var os, language, image string - var useConda bool if err := starlark.UnpackArgs(ruleBase, args, kwargs, - "os?", &os, "language?", &language, "image?", &image, "use_conda?", &useConda); err != nil { + "os?", &os, "language?", &language, "image?", &image); err != nil { return nil, err } diff --git a/pkg/lang/ir/compile.go b/pkg/lang/ir/compile.go index dac155676..9a5ec4ece 100644 --- a/pkg/lang/ir/compile.go +++ b/pkg/lang/ir/compile.go @@ -49,7 +49,7 @@ func NewGraph() *Graph { Version: &langVersion, }, CUDA: nil, - CUDNN: "8", // default version + CUDNN: CUDNNVersionDefault, NumGPUs: -1, PyPIPackages: []string{}, diff --git a/pkg/lang/ir/consts.go b/pkg/lang/ir/consts.go index fb062feb3..5bb434097 100644 --- a/pkg/lang/ir/consts.go +++ b/pkg/lang/ir/consts.go @@ -20,6 +20,7 @@ const ( osDefault = "ubuntu20.04" languageDefault = "python" languageVersionDefault = "3" + CUDNNVersionDefault = "8" aptSourceFilePath = "/etc/apt/sources.list" pypiIndexFilePath = "/etc/pip.conf" diff --git a/pkg/lang/ir/interface.go b/pkg/lang/ir/interface.go index f3fa5ddc5..38fe419d4 100644 --- a/pkg/lang/ir/interface.go +++ b/pkg/lang/ir/interface.go @@ -30,7 +30,9 @@ func Base(os, language, image string) error { Name: l, Version: version, } - DefaultGraph.OS = os + if len(os) > 0 { + DefaultGraph.OS = os + } if image != "" { DefaultGraph.Image = &image } @@ -66,7 +68,9 @@ func GPU(numGPUs int) { func CUDA(version, cudnn string) { DefaultGraph.CUDA = &version - DefaultGraph.CUDNN = cudnn + if len(cudnn) > 0 { + DefaultGraph.CUDNN = cudnn + } } func VSCodePlugins(plugins []string) error {