Skip to content

Commit

Permalink
fix(ir): make sure default value won't be replaced with empty value (#…
Browse files Browse the repository at this point in the history
…970)

fix(ir): default value

Signed-off-by: Keming <kemingyang@tensorchord.ai>

Signed-off-by: Keming <kemingyang@tensorchord.ai>
  • Loading branch information
kemingy authored Oct 7, 2022
1 parent c1ae887 commit 17fedb8
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 12 deletions.
11 changes: 4 additions & 7 deletions pkg/lang/frontend/starlark/install/install.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,16 @@ func ruleFuncSystemPackage(thread *starlark.Thread, _ *starlark.Builtin,

func ruleFuncCUDA(thread *starlark.Thread, _ *starlark.Builtin,
args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var version, cudnn starlark.String
var version, cudnn string

if err := starlark.UnpackArgs(ruleCUDA, args, kwargs,
"version?", &version, "cudnn?", &cudnn); err != nil {
"version", &version, "cudnn?", &cudnn); err != nil {
return nil, err
}

versionStr := version.GoString()
cudnnStr := cudnn.GoString()

logger.Debugf("rule `%s` is invoked, version=%s, cudnn=%s",
ruleCUDA, versionStr, cudnnStr)
ir.CUDA(versionStr, cudnnStr)
ruleCUDA, version, cudnn)
ir.CUDA(version, cudnn)

return starlark.None, nil
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/lang/frontend/starlark/universe/universe.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@ func RegisterBuildContext(buildContextDir string) {
func ruleFuncBase(thread *starlark.Thread, _ *starlark.Builtin,
args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var os, language, image string
var useConda bool

if err := starlark.UnpackArgs(ruleBase, args, kwargs,
"os?", &os, "language?", &language, "image?", &image, "use_conda?", &useConda); err != nil {
"os?", &os, "language?", &language, "image?", &image); err != nil {
return nil, err
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/lang/ir/compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewGraph() *Graph {
Version: &langVersion,
},
CUDA: nil,
CUDNN: "8", // default version
CUDNN: CUDNNVersionDefault,
NumGPUs: -1,

PyPIPackages: []string{},
Expand Down
1 change: 1 addition & 0 deletions pkg/lang/ir/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
osDefault = "ubuntu20.04"
languageDefault = "python"
languageVersionDefault = "3"
CUDNNVersionDefault = "8"

aptSourceFilePath = "/etc/apt/sources.list"
pypiIndexFilePath = "/etc/pip.conf"
Expand Down
8 changes: 6 additions & 2 deletions pkg/lang/ir/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ func Base(os, language, image string) error {
Name: l,
Version: version,
}
DefaultGraph.OS = os
if len(os) > 0 {
DefaultGraph.OS = os
}
if image != "" {
DefaultGraph.Image = &image
}
Expand Down Expand Up @@ -66,7 +68,9 @@ func GPU(numGPUs int) {

func CUDA(version, cudnn string) {
DefaultGraph.CUDA = &version
DefaultGraph.CUDNN = cudnn
if len(cudnn) > 0 {
DefaultGraph.CUDNN = cudnn
}
}

func VSCodePlugins(plugins []string) error {
Expand Down

0 comments on commit 17fedb8

Please sign in to comment.