From 3ea38ad2cf68effe45418bfd12c7eee54961d054 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Tue, 20 Aug 2024 15:38:31 -0400 Subject: [PATCH 1/8] Do not strip torch patch versions * Support base images based on torch patch versions. --- pkg/config/compatibility.go | 50 +----------------------------- pkg/config/compatibility_test.go | 53 -------------------------------- pkg/dockerfile/base.go | 2 +- pkg/dockerfile/generator.go | 7 ----- 4 files changed, 2 insertions(+), 110 deletions(-) diff --git a/pkg/config/compatibility.go b/pkg/config/compatibility.go index a41185376e..600fadef42 100644 --- a/pkg/config/compatibility.go +++ b/pkg/config/compatibility.go @@ -97,11 +97,6 @@ var TFCompatibilityMatrix []TFCompatibility var torchCompatibilityMatrixData []byte var TorchCompatibilityMatrix []TorchCompatibility -// For minor Torch versions we use the latest patch version. -// This is semantically different to pip, which uses the .0 -// patch version. -var TorchMinorCompatibilityMatrix []TorchCompatibility - func init() { if err := json.Unmarshal(cudaBaseImagesData, &CUDABaseImages); err != nil { console.Fatalf("Failed to load embedded CUDA base images: %s", err) @@ -125,45 +120,6 @@ func init() { } } TorchCompatibilityMatrix = filteredTorchCompatibilityMatrix - TorchMinorCompatibilityMatrix = generateTorchMinorVersionCompatibilityMatrix(TorchCompatibilityMatrix) -} - -func generateTorchMinorVersionCompatibilityMatrix(matrix []TorchCompatibility) []TorchCompatibility { - minorMatrix := []TorchCompatibility{} - - // First sort compatibilities by Torch version descending - matrixByTorchDesc := make([]TorchCompatibility, len(matrix)) - copy(matrixByTorchDesc, matrix) - sort.Slice(matrixByTorchDesc, func(i, j int) bool { - return version.Greater(matrixByTorchDesc[i].Torch, matrixByTorchDesc[j].Torch) - }) - - // Then pick CUDA for the most recent patch versions - seenCUDATorchMinor := make(map[[2]string]bool) - - for _, compat := range matrixByTorchDesc { - cudaString := "" - if compat.CUDA != nil { - cudaString = *compat.CUDA - } - torchMinor := version.StripPatch(compat.Torch) - key := [2]string{cudaString, torchMinor} - - if seen := seenCUDATorchMinor[key]; !seen { - minorMatrix = append(minorMatrix, TorchCompatibility{ - Torch: torchMinor, - CUDA: compat.CUDA, - Pythons: compat.Pythons, - Torchvision: compat.Torchvision, - Torchaudio: compat.Torchaudio, - FindLinks: compat.FindLinks, - ExtraIndexURL: compat.ExtraIndexURL, - }) - seenCUDATorchMinor[key] = true - } - } - return minorMatrix - } func cudaVersionFromTorchPlusVersion(ver string) (string, string) { @@ -221,11 +177,7 @@ func cudasFromTorch(ver string) ([]string, error) { } } slices.Sort(cudas) - for _, compat := range TorchMinorCompatibilityMatrix { - if ver == compat.TorchVersion() && compat.CUDA != nil { - cudas = append(cudas, *compat.CUDA) - } - } + return cudas, nil } diff --git a/pkg/config/compatibility_test.go b/pkg/config/compatibility_test.go index 3bd6656839..022c9ea312 100644 --- a/pkg/config/compatibility_test.go +++ b/pkg/config/compatibility_test.go @@ -12,62 +12,9 @@ func TestLatestCuDNNForCUDA(t *testing.T) { require.Equal(t, "8", actual) } -func TestGenerateTorchMinorVersionCompatibilityMatrix(t *testing.T) { - matrix := []TorchCompatibility{{ - Torch: "2.0.0", - CUDA: nil, - Pythons: []string{"3.7", "3.8"}, - }, { - Torch: "2.0.0", - CUDA: stringp("12.0"), - Pythons: []string{"3.7", "3.8"}, - }, { - Torch: "2.0.1", - CUDA: stringp("12.0"), - Pythons: []string{"3.7", "3.8", "3.9"}, - }, { - Torch: "2.0.2", - CUDA: stringp("12.0"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.1.0", - CUDA: stringp("12.2"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.1.1", - CUDA: stringp("12.3"), - Pythons: []string{"3.9", "3.10"}, - }} - actual := generateTorchMinorVersionCompatibilityMatrix(matrix) - - expected := []TorchCompatibility{{ - Torch: "2.1", - CUDA: stringp("12.3"), - Pythons: []string{"3.9", "3.10"}, - }, { - Torch: "2.1", - CUDA: stringp("12.2"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.0", - CUDA: stringp("12.0"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.0", - CUDA: nil, - Pythons: []string{"3.7", "3.8"}, - }} - - require.Equal(t, expected, actual) -} - func TestCudasFromTorchWithCUVersionModifier(t *testing.T) { cudas, err := cudasFromTorch("2.0.1+cu118") require.GreaterOrEqual(t, len(cudas), 1) require.Equal(t, cudas[0], "11.8") require.Nil(t, err) } - -func stringp(s string) *string { - return &s -} diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index 510c7a3767..0579aed13e 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -110,7 +110,7 @@ func BaseImageConfigurations() []BaseImageConfiguration { cudaVersionsSet := make(map[string]bool) // Torch configs - for _, compat := range config.TorchMinorCompatibilityMatrix { + for _, compat := range config.TorchCompatibilityMatrix { for _, python := range compat.Pythons { // Only support fast cold boots for Torch with CUDA. diff --git a/pkg/dockerfile/generator.go b/pkg/dockerfile/generator.go index ee17620fa4..f9124b5390 100644 --- a/pkg/dockerfile/generator.go +++ b/pkg/dockerfile/generator.go @@ -594,13 +594,6 @@ func (g *Generator) determineBaseImageName() (string, error) { } torchVersion, _ := g.Config.TorchVersion() - torchVersion, changed, err = stripPatchVersion(torchVersion) - if err != nil { - return "", err - } - if changed { - console.Warnf("Stripping patch version from Torch version %s to %s", g.Config.Build.PythonVersion, pythonVersion) - } // validate that the base image configuration exists imageGenerator, err := NewBaseImageGenerator(cudaVersion, pythonVersion, torchVersion) From 46375e3f40f343891f8a338ea1a58f33daed0c30 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Tue, 20 Aug 2024 16:00:01 -0400 Subject: [PATCH 2/8] Add new base images for torch 2.4.0 * Adds torch 2.4.0 compatibility in the matrix * Add CuDNN to CUDA 12.4.x compatibility --- pkg/config/cuda_base_images.json | 6 +-- pkg/config/torch_compatibility_matrix.json | 62 +++++++++++++++++++++- 2 files changed, 64 insertions(+), 4 deletions(-) diff --git a/pkg/config/cuda_base_images.json b/pkg/config/cuda_base_images.json index e92d3dd0e7..6386e05f17 100644 --- a/pkg/config/cuda_base_images.json +++ b/pkg/config/cuda_base_images.json @@ -2,14 +2,14 @@ { "Tag": "12.4.1-cudnn-devel-ubuntu22.04", "CUDA": "12.4.1", - "CuDNN": "", + "CuDNN": "9", "IsDevel": true, "Ubuntu": "22.04" }, { "Tag": "12.4.1-cudnn-devel-ubuntu20.04", "CUDA": "12.4.1", - "CuDNN": "", + "CuDNN": "9", "IsDevel": true, "Ubuntu": "20.04" }, @@ -286,4 +286,4 @@ "IsDevel": true, "Ubuntu": "16.04" } -] \ No newline at end of file +] diff --git a/pkg/config/torch_compatibility_matrix.json b/pkg/config/torch_compatibility_matrix.json index 4b30a3319b..ae7a01d7b5 100644 --- a/pkg/config/torch_compatibility_matrix.json +++ b/pkg/config/torch_compatibility_matrix.json @@ -1,4 +1,64 @@ [ + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cu124", + "CUDA": "12.4", + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cu121", + "CUDA": "12.1", + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cu118", + "CUDA": "11.8", + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cpu", + "CUDA": null, + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, { "Torch": "2.3.1+cpu", "Torchvision": "0.18.1", @@ -1409,4 +1469,4 @@ "3.11" ] } -] \ No newline at end of file +] From 3bc0d85a23a8ca564ddbc8c6d45bb364823cd121 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Tue, 20 Aug 2024 16:05:52 -0400 Subject: [PATCH 3/8] Exclude torchaudio from pip installs * Use our torch audio from the cog base images instead --- pkg/config/config.go | 4 ++++ pkg/dockerfile/generator.go | 3 +++ 2 files changed, 7 insertions(+) diff --git a/pkg/config/config.go b/pkg/config/config.go index 76af8a5a7a..c94f0a9a8a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -181,6 +181,10 @@ func (c *Config) TorchvisionVersion() (string, bool) { return c.pythonPackageVersion("torchvision") } +func (c *Config) TorchaudioVersion() (string, bool) { + return c.pythonPackageVersion("torchaudio") +} + func (c *Config) TensorFlowVersion() (string, bool) { return c.pythonPackageVersion("tensorflow") } diff --git a/pkg/dockerfile/generator.go b/pkg/dockerfile/generator.go index f9124b5390..ec5e9cf52b 100644 --- a/pkg/dockerfile/generator.go +++ b/pkg/dockerfile/generator.go @@ -401,6 +401,9 @@ func (g *Generator) pipInstalls() (string, error) { if torchvisionVersion, ok := g.Config.TorchvisionVersion(); ok { excludePackages = append(excludePackages, "torchvision=="+torchvisionVersion) } + if torchaudioVersion, ok := g.Config.TorchaudioVersion(); ok { + excludePackages = append(excludePackages, "torchaudio=="+torchaudioVersion) + } g.pythonRequirementsContents, err = g.Config.PythonRequirementsForArch(g.GOOS, g.GOARCH, excludePackages) if err != nil { return "", err From 521a5196c748c37aed5973ddc06cc7d29139c9b9 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Tue, 20 Aug 2024 17:03:12 -0400 Subject: [PATCH 4/8] Remove version modifiers from base image name * Do not forward cu118 et al version modifiers --- pkg/dockerfile/base.go | 2 +- pkg/dockerfile/base_test.go | 5 +++ pkg/dockerfile/generator_test.go | 54 +++++++++++++++++++++++++++++--- pkg/util/version/version.go | 5 +++ pkg/util/version/version_test.go | 7 +++++ 5 files changed, 68 insertions(+), 5 deletions(-) diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index 0579aed13e..6ad909b0dc 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -226,7 +226,7 @@ func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string components = append(components, "python"+version.StripPatch(pythonVersion)) } if torchVersion != "" { - components = append(components, "torch"+version.StripPatch(torchVersion)) + components = append(components, "torch"+version.StripModifier(torchVersion)) } tag := strings.Join(components, "-") diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index 25dd38bd23..aa26d699ba 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -28,3 +28,8 @@ func TestBaseImageName(t *testing.T) { require.Equal(t, tt.expected, actual) } } + +func TestBaseImageNameWithVersionModifier(t *testing.T) { + actual := BaseImageName("12.1", "3.8", "2.0.1+cu118") + require.Equal(t, "r8.im/cog-base:cuda12.1-python3.8-torch2.0.1", actual) +} diff --git a/pkg/dockerfile/generator_test.go b/pkg/dockerfile/generator_test.go index 77e73e8ff5..b7c2824aae 100644 --- a/pkg/dockerfile/generator_test.go +++ b/pkg/dockerfile/generator_test.go @@ -580,17 +580,17 @@ predict: predict.py:Predictor _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) - expected := `#syntax=docker/dockerfile:1.4 + expected := fmt.Sprintf(`#syntax=docker/dockerfile:1.4 FROM r8.im/replicate/cog-test-weights AS weights -FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3 +FROM r8.im/cog-base:cuda11.8-python3.12-torch%s RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* -COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt +COPY `+gen.relativeTmpDir+`/requirements.txt /tmp/requirements.txt RUN pip install -r /tmp/requirements.txt RUN cowsay moo WORKDIR /src EXPOSE 5000 CMD ["python", "-m", "cog.server.http"] -COPY . /src` +COPY . /src`, torchVersion) require.Equal(t, expected, actual) @@ -599,3 +599,49 @@ COPY . /src` require.Equal(t, "pandas==2.0.3", string(requirements)) } } + +func TestGenerateTorchWithStrippedModifiedVersion(t *testing.T) { + tmpDir := t.TempDir() + + yaml := ` +build: + gpu: true + cuda: "11.8" + system_packages: + - ffmpeg + - cowsay + python_packages: + - torch==2.3.1+cu118 + - pandas==2.0.3 + run: + - "cowsay moo" +predict: predict.py:Predictor +` + conf, err := config.FromYAML([]byte(yaml)) + require.NoError(t, err) + require.NoError(t, conf.ValidateAndComplete("")) + + gen, err := NewGenerator(conf, tmpDir) + require.NoError(t, err) + gen.SetUseCogBaseImage(true) + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + require.NoError(t, err) + + expected := `#syntax=docker/dockerfile:1.4 +FROM r8.im/replicate/cog-test-weights AS weights +FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1 +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* +COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt +RUN pip install -r /tmp/requirements.txt +RUN cowsay moo +WORKDIR /src +EXPOSE 5000 +CMD ["python", "-m", "cog.server.http"] +COPY . /src` + + require.Equal(t, expected, actual) + + requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt")) + require.NoError(t, err) + require.Equal(t, "pandas==2.0.3", string(requirements)) +} diff --git a/pkg/util/version/version.go b/pkg/util/version/version.go index 7c3c6afeab..8883fff0aa 100644 --- a/pkg/util/version/version.go +++ b/pkg/util/version/version.go @@ -115,3 +115,8 @@ func StripPatch(v string) string { ver := MustVersion(v) return fmt.Sprintf("%d.%d", ver.Major, ver.Minor) } + +func StripModifier(v string) string { + modifierSplit := strings.Split(v, "+") + return modifierSplit[0] +} diff --git a/pkg/util/version/version_test.go b/pkg/util/version/version_test.go index af83caafac..ac07d19afb 100644 --- a/pkg/util/version/version_test.go +++ b/pkg/util/version/version_test.go @@ -60,3 +60,10 @@ func TestVersionGreater(t *testing.T) { require.Equal(t, tt.greater, Greater(tt.v1, tt.v2), "%s is %sgreater than %s", tt.v1, not, tt.v2) } } + +func TestVersionStripModifier(t *testing.T) { + version := "2.3.1" + versionWithModifier := version + "+cu118" + versionWithoutModifier := StripModifier(versionWithModifier) + require.Equal(t, versionWithoutModifier, version) +} From 43f0d662e9f17201fcd275e64aa80130e570f5c9 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Tue, 20 Aug 2024 17:38:02 -0400 Subject: [PATCH 5/8] Set useCogBaseImage to false in base-image * We never want to useCogBaseImage when generating base images. --- pkg/dockerfile/base.go | 2 ++ pkg/dockerfile/base_test.go | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index 6ad909b0dc..a22c15668a 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -180,6 +180,8 @@ func (g *BaseImageGenerator) GenerateDockerfile() (string, error) { if err != nil { return "", err } + useCogBaseImage := false + generator.useCogBaseImage = &useCogBaseImage dockerfile, err := generator.generateInitialSteps() if err != nil { diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index aa26d699ba..f05e0396ef 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -1,6 +1,7 @@ package dockerfile import ( + "strings" "testing" "github.com/stretchr/testify/require" @@ -33,3 +34,15 @@ func TestBaseImageNameWithVersionModifier(t *testing.T) { actual := BaseImageName("12.1", "3.8", "2.0.1+cu118") require.Equal(t, "r8.im/cog-base:cuda12.1-python3.8-torch2.0.1", actual) } + +func TestGenerateDockerfile(t *testing.T) { + generator, err := NewBaseImageGenerator( + "12.1", + "3.8", + "2.1.0", + ) + require.NoError(t, err) + dockerfile, err := generator.GenerateDockerfile() + require.NoError(t, err) + require.True(t, strings.Contains(dockerfile, "FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04")) +} From a73a37c0ceef008854b27eb6e490c9e16bc8f82a Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 21 Aug 2024 14:51:23 -0400 Subject: [PATCH 6/8] Fix more torch patch version bugs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make sure the modifier is always stripped from torch (we don’t need it anymore) * Support major/minor versions of torch that resolve to latest patch version * Add the concept of a set Patch to version, that allows logic to know whether the patch version is explicitly set or not. --- pkg/config/compatibility.go | 10 ++++------ pkg/dockerfile/base.go | 22 ++++++++++++++++------ pkg/dockerfile/base_test.go | 16 ++++++++++++++++ pkg/dockerfile/generator_test.go | 10 ++++++++-- pkg/util/version/version.go | 28 +++++++++++++++++++++++----- pkg/util/version/version_test.go | 12 ++++++++++++ 6 files changed, 79 insertions(+), 19 deletions(-) diff --git a/pkg/config/compatibility.go b/pkg/config/compatibility.go index 600fadef42..6fca98dcae 100644 --- a/pkg/config/compatibility.go +++ b/pkg/config/compatibility.go @@ -61,13 +61,11 @@ type TorchCompatibility struct { } func (c *TorchCompatibility) TorchVersion() string { - parts := strings.Split(c.Torch, "+") - return parts[0] + return version.StripModifier(c.Torch) } func (c *TorchCompatibility) TorchvisionVersion() string { - parts := strings.Split(c.Torchvision, "+") - return parts[0] + return version.StripModifier(c.Torchvision) } type CUDABaseImage struct { @@ -164,7 +162,7 @@ func cudasFromTorch(ver string) ([]string, error) { if compat.CUDA == nil { continue } - if ver == compat.TorchVersion() && *compat.CUDA == cudaVer { + if version.Matches(ver, compat.TorchVersion()) && *compat.CUDA == cudaVer { cudas = append(cudas, *compat.CUDA) return cudas, nil } @@ -172,7 +170,7 @@ func cudasFromTorch(ver string) ([]string, error) { } for _, compat := range TorchCompatibilityMatrix { - if ver == compat.TorchVersion() && compat.CUDA != nil { + if version.Matches(ver, compat.TorchVersion()) && compat.CUDA != nil { cudas = append(cudas, *compat.CUDA) } } diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index a22c15668a..9df09ebdc0 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -120,7 +120,7 @@ func BaseImageConfigurations() []BaseImageConfiguration { } cuda := *compat.CUDA - torch := version.StripPatch(compat.Torch) + torch := compat.Torch conf := BaseImageConfiguration{ CUDAVersion: cuda, PythonVersion: python, @@ -158,7 +158,8 @@ func BaseImageConfigurations() []BaseImageConfiguration { } func NewBaseImageGenerator(cudaVersion string, pythonVersion string, torchVersion string) (*BaseImageGenerator, error) { - if BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion) { + valid, cudaVersion, pythonVersion, torchVersion := BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion) + if valid { return &BaseImageGenerator{cudaVersion, pythonVersion, torchVersion}, nil } printNone := func(s string) string { @@ -239,7 +240,8 @@ func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string return BaseImageRegistry + "/cog-base:" + tag } -func BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion string) bool { +func BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion string) (bool, string, string, string) { + compatibleTorchVersion := "" for _, conf := range BaseImageConfigurations() { // Check CUDA version compatibility if !isVersionCompatible(conf.CUDAVersion, cudaVersion) { @@ -256,14 +258,22 @@ func BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion strin continue } - return true + if compatibleTorchVersion == "" || version.Greater(conf.TorchVersion, compatibleTorchVersion) { + compatibleTorchVersion = version.StripModifier(conf.TorchVersion) + } + } + + valid := (torchVersion != "" && compatibleTorchVersion != "") || torchVersion == "" + if valid { + torchVersion = compatibleTorchVersion } - return false + + return valid, cudaVersion, pythonVersion, torchVersion } func isVersionCompatible(confVersion, requestedVersion string) bool { if confVersion == "" || requestedVersion == "" { return confVersion == requestedVersion } - return version.Matches(confVersion, requestedVersion) + return version.Matches(requestedVersion, confVersion) } diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index f05e0396ef..a792de78ab 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -46,3 +46,19 @@ func TestGenerateDockerfile(t *testing.T) { require.NoError(t, err) require.True(t, strings.Contains(dockerfile, "FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04")) } + +func TestBaseImageConfigurationExists(t *testing.T) { + exists, _, _, torchVersion := BaseImageConfigurationExists("12.1", "3.9", "2.3") + require.True(t, exists) + require.Equal(t, "2.3.1", torchVersion) +} + +func TestBaseImageConfigurationExistsNoTorch(t *testing.T) { + exists, _, _, _ := BaseImageConfigurationExists("", "3.12", "") + require.True(t, exists) +} + +func TestIsVersionCompatible(t *testing.T) { + compatible := isVersionCompatible("2.3.1+cu121", "2.3") + require.True(t, compatible) +} diff --git a/pkg/dockerfile/generator_test.go b/pkg/dockerfile/generator_test.go index b7c2824aae..832be0886e 100644 --- a/pkg/dockerfile/generator_test.go +++ b/pkg/dockerfile/generator_test.go @@ -560,6 +560,7 @@ func TestGenerateFullGPUWithCogBaseImage(t *testing.T) { build: gpu: true cuda: "11.8" + python_version: "3.11" system_packages: - ffmpeg - cowsay @@ -580,9 +581,14 @@ predict: predict.py:Predictor _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) + // We add the patch version to the expected torch version + expectedTorchVersion := torchVersion + if torchVersion == "2.3" { + expectedTorchVersion = "2.3.1" + } expected := fmt.Sprintf(`#syntax=docker/dockerfile:1.4 FROM r8.im/replicate/cog-test-weights AS weights -FROM r8.im/cog-base:cuda11.8-python3.12-torch%s +FROM r8.im/cog-base:cuda11.8-python3.11-torch%s RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* COPY `+gen.relativeTmpDir+`/requirements.txt /tmp/requirements.txt RUN pip install -r /tmp/requirements.txt @@ -590,7 +596,7 @@ RUN cowsay moo WORKDIR /src EXPOSE 5000 CMD ["python", "-m", "cog.server.http"] -COPY . /src`, torchVersion) +COPY . /src`, expectedTorchVersion) require.Equal(t, expected, actual) diff --git a/pkg/util/version/version.go b/pkg/util/version/version.go index 8883fff0aa..ca62150c6e 100644 --- a/pkg/util/version/version.go +++ b/pkg/util/version/version.go @@ -9,7 +9,7 @@ import ( type Version struct { Major int Minor int - Patch int + Patch *int Metadata string } @@ -32,7 +32,9 @@ func NewVersion(s string) (version *Version, err error) { } } if len(parts) >= 3 { - version.Patch, err = strconv.Atoi(parts[2]) + patch, err := strconv.Atoi(parts[2]) + version.Patch = new(int) + *version.Patch = patch if err != nil { return nil, fmt.Errorf("Invalid patch version %s: %w", parts[2], err) } @@ -59,7 +61,9 @@ func (v *Version) Greater(other *Version) bool { return true case v.Major == other.Major && v.Minor > other.Minor: return true - case v.Major == other.Major && v.Minor == other.Minor && v.Patch > other.Patch: + case v.Major == other.Major && + v.Minor == other.Minor && + v.PatchVersion() > other.PatchVersion(): return true default: return false @@ -67,7 +71,10 @@ func (v *Version) Greater(other *Version) bool { } func (v *Version) Equal(other *Version) bool { - return v.Major == other.Major && v.Minor == other.Minor && v.Patch == other.Patch && v.Metadata == other.Metadata + return v.Major == other.Major && + v.Minor == other.Minor && + v.PatchVersion() == other.PatchVersion() && + v.Metadata == other.Metadata } func (v *Version) GreaterOrEqual(other *Version) bool { @@ -78,6 +85,17 @@ func (v *Version) EqualMinor(other *Version) bool { return v.Major == other.Major && v.Minor == other.Minor } +func (v *Version) HasPatch() bool { + return v.Patch != nil +} + +func (v *Version) PatchVersion() int { + if v.Patch == nil { + return 0 + } + return *v.Patch +} + func Equal(v1 string, v2 string) bool { return MustVersion(v1).Equal(MustVersion(v2)) } @@ -100,7 +118,7 @@ func (v *Version) Matches(other *Version) bool { return false case v.Minor != other.Minor: return false - case v.Patch != 0 && v.Patch != other.Patch: + case v.HasPatch() && other.HasPatch() && *v.Patch != *other.Patch: return false default: return true diff --git a/pkg/util/version/version_test.go b/pkg/util/version/version_test.go index ac07d19afb..e3d7f96e51 100644 --- a/pkg/util/version/version_test.go +++ b/pkg/util/version/version_test.go @@ -67,3 +67,15 @@ func TestVersionStripModifier(t *testing.T) { versionWithoutModifier := StripModifier(versionWithModifier) require.Equal(t, versionWithoutModifier, version) } + +func TestVersionMatches(t *testing.T) { + version := "2.3" + matchVersion := "2.3.2" + require.True(t, Matches(version, matchVersion)) +} + +func TestVersionMatchesModifier(t *testing.T) { + version := "2.3" + matchVersion := "2.3.2+cu118" + require.True(t, Matches(version, matchVersion)) +} From f695fa782246f59209c0c4aa959ba2a77c46f867 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 21 Aug 2024 14:53:07 -0400 Subject: [PATCH 7/8] Install torchvision and torchaudio in base images * Add the compatible versions of torchvision and torchaudio to the base image python packages --- pkg/dockerfile/base.go | 30 +++++++++++++++++++++++++++++- pkg/dockerfile/base_test.go | 8 ++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index 9df09ebdc0..fc5af2e9f6 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -211,7 +211,35 @@ func (g *BaseImageGenerator) makeConfig() (*config.Config, error) { func (g *BaseImageGenerator) pythonPackages() []string { if g.torchVersion != "" { - return []string{"torch==" + g.torchVersion} + pkgs := []string{"torch==" + g.torchVersion} + + // Find torchvision compatibility. + for _, compat := range config.TorchCompatibilityMatrix { + if len(compat.Torchvision) == 0 { + continue + } + if !version.Matches(g.torchVersion, compat.TorchVersion()) { + continue + } + + pkgs = append(pkgs, "torchvision=="+compat.Torchvision) + break + } + + // Find torchaudio compatibility. + for _, compat := range config.TorchCompatibilityMatrix { + if len(compat.Torchaudio) == 0 { + continue + } + if !version.Matches(g.torchVersion, compat.TorchVersion()) { + continue + } + + pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio) + break + } + + return pkgs } return []string{} } diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index a792de78ab..69a0c06ee4 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -1,6 +1,7 @@ package dockerfile import ( + "reflect" "strings" "testing" @@ -62,3 +63,10 @@ func TestIsVersionCompatible(t *testing.T) { compatible := isVersionCompatible("2.3.1+cu121", "2.3") require.True(t, compatible) } + +func TestPythonPackages(t *testing.T) { + generator, err := NewBaseImageGenerator("12.1", "3.9", "2.1.0") + require.NoError(t, err) + pkgs := generator.pythonPackages() + require.Truef(t, reflect.DeepEqual(pkgs, []string{"torch==2.1.0", "torchvision==0.16.0", "torchaudio==2.1.0"}), "expected %v", pkgs) +} From f8a810864b43e986035a895e9921776e417e3e60 Mon Sep 17 00:00:00 2001 From: Will Sackfield Date: Wed, 21 Aug 2024 15:52:56 -0400 Subject: [PATCH 8/8] Resolve torch patch versions in base image name * When rendering the base image name, resolve the latest torch patch version if not supplied * Add BaseImageConfigurations that conform to the null CUDA --- pkg/dockerfile/base.go | 10 ++++++---- pkg/dockerfile/base_test.go | 12 +++++++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index fc5af2e9f6..d85a4c49c7 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -99,8 +99,6 @@ func (b BaseImageConfiguration) MarshalJSON() ([]byte, error) { } // BaseImageConfigurations returns a list of CUDA/Python/Torch versions -// with patch versions stripped out. Each version is greater or equal to -// MinimumCUDAVersion/MinimumPythonVersion/MinimumTorchVersion. func BaseImageConfigurations() []BaseImageConfiguration { configs := []BaseImageConfiguration{} @@ -113,9 +111,11 @@ func BaseImageConfigurations() []BaseImageConfiguration { for _, compat := range config.TorchCompatibilityMatrix { for _, python := range compat.Pythons { - // Only support fast cold boots for Torch with CUDA. - // Torch without CUDA is a rarely used edge case. if compat.CUDA == nil { + configs = append(configs, BaseImageConfiguration{ + PythonVersion: python, + TorchVersion: compat.Torch, + }) continue } @@ -249,6 +249,8 @@ func (g *BaseImageGenerator) runStatements() []config.RunItem { } func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string) string { + _, cudaVersion, pythonVersion, torchVersion = BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion) + components := []string{} if cudaVersion != "" { components = append(components, "cuda"+version.StripPatch(cudaVersion)) diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index 69a0c06ee4..3b2415b9fa 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -18,13 +18,13 @@ func TestBaseImageName(t *testing.T) { {"", "3.8", "", "r8.im/cog-base:python3.8"}, {"", "3.8", "2.1", - "r8.im/cog-base:python3.8-torch2.1"}, + "r8.im/cog-base:python3.8-torch2.1.2"}, {"12.1", "3.8", "", "r8.im/cog-base:cuda12.1-python3.8"}, {"12.1", "3.8", "2.1", - "r8.im/cog-base:cuda12.1-python3.8-torch2.1"}, + "r8.im/cog-base:cuda12.1-python3.8-torch2.1.2"}, {"12.1", "3.8", "2.1", - "r8.im/cog-base:cuda12.1-python3.8-torch2.1"}, + "r8.im/cog-base:cuda12.1-python3.8-torch2.1.2"}, } { actual := BaseImageName(tt.cuda, tt.python, tt.torch) require.Equal(t, tt.expected, actual) @@ -59,6 +59,12 @@ func TestBaseImageConfigurationExistsNoTorch(t *testing.T) { require.True(t, exists) } +func TestBaseImageConfigurationExistsNoCUDA(t *testing.T) { + exists, _, _, torchVersion := BaseImageConfigurationExists("", "3.8", "2.1") + require.True(t, exists) + require.Equal(t, "2.1.2", torchVersion) +} + func TestIsVersionCompatible(t *testing.T) { compatible := isVersionCompatible("2.3.1+cu121", "2.3") require.True(t, compatible)