Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support new base images #1890

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 5 additions & 55 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,11 @@ type TorchCompatibility struct {
}

func (c *TorchCompatibility) TorchVersion() string {
parts := strings.Split(c.Torch, "+")
return parts[0]
return version.StripModifier(c.Torch)
}

func (c *TorchCompatibility) TorchvisionVersion() string {
parts := strings.Split(c.Torchvision, "+")
return parts[0]
return version.StripModifier(c.Torchvision)
}

type CUDABaseImage struct {
Expand Down Expand Up @@ -97,11 +95,6 @@ var TFCompatibilityMatrix []TFCompatibility
var torchCompatibilityMatrixData []byte
var TorchCompatibilityMatrix []TorchCompatibility

// For minor Torch versions we use the latest patch version.
// This is semantically different to pip, which uses the .0
// patch version.
var TorchMinorCompatibilityMatrix []TorchCompatibility

func init() {
if err := json.Unmarshal(cudaBaseImagesData, &CUDABaseImages); err != nil {
console.Fatalf("Failed to load embedded CUDA base images: %s", err)
Expand All @@ -125,45 +118,6 @@ func init() {
}
}
TorchCompatibilityMatrix = filteredTorchCompatibilityMatrix
TorchMinorCompatibilityMatrix = generateTorchMinorVersionCompatibilityMatrix(TorchCompatibilityMatrix)
}

func generateTorchMinorVersionCompatibilityMatrix(matrix []TorchCompatibility) []TorchCompatibility {
minorMatrix := []TorchCompatibility{}

// First sort compatibilities by Torch version descending
matrixByTorchDesc := make([]TorchCompatibility, len(matrix))
copy(matrixByTorchDesc, matrix)
sort.Slice(matrixByTorchDesc, func(i, j int) bool {
return version.Greater(matrixByTorchDesc[i].Torch, matrixByTorchDesc[j].Torch)
})

// Then pick CUDA for the most recent patch versions
seenCUDATorchMinor := make(map[[2]string]bool)

for _, compat := range matrixByTorchDesc {
cudaString := ""
if compat.CUDA != nil {
cudaString = *compat.CUDA
}
torchMinor := version.StripPatch(compat.Torch)
key := [2]string{cudaString, torchMinor}

if seen := seenCUDATorchMinor[key]; !seen {
minorMatrix = append(minorMatrix, TorchCompatibility{
Torch: torchMinor,
CUDA: compat.CUDA,
Pythons: compat.Pythons,
Torchvision: compat.Torchvision,
Torchaudio: compat.Torchaudio,
FindLinks: compat.FindLinks,
ExtraIndexURL: compat.ExtraIndexURL,
})
seenCUDATorchMinor[key] = true
}
}
return minorMatrix

}

func cudaVersionFromTorchPlusVersion(ver string) (string, string) {
Expand Down Expand Up @@ -208,24 +162,20 @@ func cudasFromTorch(ver string) ([]string, error) {
if compat.CUDA == nil {
continue
}
if ver == compat.TorchVersion() && *compat.CUDA == cudaVer {
if version.Matches(ver, compat.TorchVersion()) && *compat.CUDA == cudaVer {
cudas = append(cudas, *compat.CUDA)
return cudas, nil
}
}
}

for _, compat := range TorchCompatibilityMatrix {
if ver == compat.TorchVersion() && compat.CUDA != nil {
if version.Matches(ver, compat.TorchVersion()) && compat.CUDA != nil {
cudas = append(cudas, *compat.CUDA)
}
}
slices.Sort(cudas)
for _, compat := range TorchMinorCompatibilityMatrix {
if ver == compat.TorchVersion() && compat.CUDA != nil {
cudas = append(cudas, *compat.CUDA)
}
}

return cudas, nil
}

Expand Down
53 changes: 0 additions & 53 deletions pkg/config/compatibility_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,62 +12,9 @@ func TestLatestCuDNNForCUDA(t *testing.T) {
require.Equal(t, "8", actual)
}

func TestGenerateTorchMinorVersionCompatibilityMatrix(t *testing.T) {
matrix := []TorchCompatibility{{
Torch: "2.0.0",
CUDA: nil,
Pythons: []string{"3.7", "3.8"},
}, {
Torch: "2.0.0",
CUDA: stringp("12.0"),
Pythons: []string{"3.7", "3.8"},
}, {
Torch: "2.0.1",
CUDA: stringp("12.0"),
Pythons: []string{"3.7", "3.8", "3.9"},
}, {
Torch: "2.0.2",
CUDA: stringp("12.0"),
Pythons: []string{"3.8", "3.9"},
}, {
Torch: "2.1.0",
CUDA: stringp("12.2"),
Pythons: []string{"3.8", "3.9"},
}, {
Torch: "2.1.1",
CUDA: stringp("12.3"),
Pythons: []string{"3.9", "3.10"},
}}
actual := generateTorchMinorVersionCompatibilityMatrix(matrix)

expected := []TorchCompatibility{{
Torch: "2.1",
CUDA: stringp("12.3"),
Pythons: []string{"3.9", "3.10"},
}, {
Torch: "2.1",
CUDA: stringp("12.2"),
Pythons: []string{"3.8", "3.9"},
}, {
Torch: "2.0",
CUDA: stringp("12.0"),
Pythons: []string{"3.8", "3.9"},
}, {
Torch: "2.0",
CUDA: nil,
Pythons: []string{"3.7", "3.8"},
}}

require.Equal(t, expected, actual)
}

func TestCudasFromTorchWithCUVersionModifier(t *testing.T) {
cudas, err := cudasFromTorch("2.0.1+cu118")
require.GreaterOrEqual(t, len(cudas), 1)
require.Equal(t, cudas[0], "11.8")
require.Nil(t, err)
}

func stringp(s string) *string {
return &s
}
4 changes: 4 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ func (c *Config) TorchvisionVersion() (string, bool) {
return c.pythonPackageVersion("torchvision")
}

func (c *Config) TorchaudioVersion() (string, bool) {
return c.pythonPackageVersion("torchaudio")
}

func (c *Config) TensorFlowVersion() (string, bool) {
return c.pythonPackageVersion("tensorflow")
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/config/cuda_base_images.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
{
"Tag": "12.4.1-cudnn-devel-ubuntu22.04",
"CUDA": "12.4.1",
"CuDNN": "",
"CuDNN": "9",
"IsDevel": true,
"Ubuntu": "22.04"
},
{
"Tag": "12.4.1-cudnn-devel-ubuntu20.04",
"CUDA": "12.4.1",
"CuDNN": "",
"CuDNN": "9",
"IsDevel": true,
"Ubuntu": "20.04"
},
Expand Down Expand Up @@ -286,4 +286,4 @@
"IsDevel": true,
"Ubuntu": "16.04"
}
]
]
62 changes: 61 additions & 1 deletion pkg/config/torch_compatibility_matrix.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,64 @@
[
{
"Torch": "2.4.0",
"Torchvision": "0.19.0",
"Torchaudio": "2.4.0",
"FindLinks": "",
"ExtraIndexURL": "https://download.pytorch.org/whl/cu124",
"CUDA": "12.4",
"Pythons": [
"3.10",
"3.11",
"3.12",
"3.8",
"3.9"
]
},
{
"Torch": "2.4.0",
"Torchvision": "0.19.0",
"Torchaudio": "2.4.0",
"FindLinks": "",
"ExtraIndexURL": "https://download.pytorch.org/whl/cu121",
"CUDA": "12.1",
"Pythons": [
"3.10",
"3.11",
"3.12",
"3.8",
"3.9"
]
},
{
"Torch": "2.4.0",
"Torchvision": "0.19.0",
"Torchaudio": "2.4.0",
"FindLinks": "",
"ExtraIndexURL": "https://download.pytorch.org/whl/cu118",
"CUDA": "11.8",
"Pythons": [
"3.10",
"3.11",
"3.12",
"3.8",
"3.9"
]
},
{
"Torch": "2.4.0",
"Torchvision": "0.19.0",
"Torchaudio": "2.4.0",
"FindLinks": "",
"ExtraIndexURL": "https://download.pytorch.org/whl/cpu",
"CUDA": null,
"Pythons": [
"3.10",
"3.11",
"3.12",
"3.8",
"3.9"
]
},
{
"Torch": "2.3.1+cpu",
"Torchvision": "0.18.1",
Expand Down Expand Up @@ -1409,4 +1469,4 @@
"3.11"
]
}
]
]
Loading
Loading