diff --git a/pkg/config/compatibility.go b/pkg/config/compatibility.go index a41185376e..6fca98dcae 100644 --- a/pkg/config/compatibility.go +++ b/pkg/config/compatibility.go @@ -61,13 +61,11 @@ type TorchCompatibility struct { } func (c *TorchCompatibility) TorchVersion() string { - parts := strings.Split(c.Torch, "+") - return parts[0] + return version.StripModifier(c.Torch) } func (c *TorchCompatibility) TorchvisionVersion() string { - parts := strings.Split(c.Torchvision, "+") - return parts[0] + return version.StripModifier(c.Torchvision) } type CUDABaseImage struct { @@ -97,11 +95,6 @@ var TFCompatibilityMatrix []TFCompatibility var torchCompatibilityMatrixData []byte var TorchCompatibilityMatrix []TorchCompatibility -// For minor Torch versions we use the latest patch version. -// This is semantically different to pip, which uses the .0 -// patch version. -var TorchMinorCompatibilityMatrix []TorchCompatibility - func init() { if err := json.Unmarshal(cudaBaseImagesData, &CUDABaseImages); err != nil { console.Fatalf("Failed to load embedded CUDA base images: %s", err) @@ -125,45 +118,6 @@ func init() { } } TorchCompatibilityMatrix = filteredTorchCompatibilityMatrix - TorchMinorCompatibilityMatrix = generateTorchMinorVersionCompatibilityMatrix(TorchCompatibilityMatrix) -} - -func generateTorchMinorVersionCompatibilityMatrix(matrix []TorchCompatibility) []TorchCompatibility { - minorMatrix := []TorchCompatibility{} - - // First sort compatibilities by Torch version descending - matrixByTorchDesc := make([]TorchCompatibility, len(matrix)) - copy(matrixByTorchDesc, matrix) - sort.Slice(matrixByTorchDesc, func(i, j int) bool { - return version.Greater(matrixByTorchDesc[i].Torch, matrixByTorchDesc[j].Torch) - }) - - // Then pick CUDA for the most recent patch versions - seenCUDATorchMinor := make(map[[2]string]bool) - - for _, compat := range matrixByTorchDesc { - cudaString := "" - if compat.CUDA != nil { - cudaString = *compat.CUDA - } - torchMinor := version.StripPatch(compat.Torch) - key := [2]string{cudaString, torchMinor} - - if seen := seenCUDATorchMinor[key]; !seen { - minorMatrix = append(minorMatrix, TorchCompatibility{ - Torch: torchMinor, - CUDA: compat.CUDA, - Pythons: compat.Pythons, - Torchvision: compat.Torchvision, - Torchaudio: compat.Torchaudio, - FindLinks: compat.FindLinks, - ExtraIndexURL: compat.ExtraIndexURL, - }) - seenCUDATorchMinor[key] = true - } - } - return minorMatrix - } func cudaVersionFromTorchPlusVersion(ver string) (string, string) { @@ -208,7 +162,7 @@ func cudasFromTorch(ver string) ([]string, error) { if compat.CUDA == nil { continue } - if ver == compat.TorchVersion() && *compat.CUDA == cudaVer { + if version.Matches(ver, compat.TorchVersion()) && *compat.CUDA == cudaVer { cudas = append(cudas, *compat.CUDA) return cudas, nil } @@ -216,16 +170,12 @@ func cudasFromTorch(ver string) ([]string, error) { } for _, compat := range TorchCompatibilityMatrix { - if ver == compat.TorchVersion() && compat.CUDA != nil { + if version.Matches(ver, compat.TorchVersion()) && compat.CUDA != nil { cudas = append(cudas, *compat.CUDA) } } slices.Sort(cudas) - for _, compat := range TorchMinorCompatibilityMatrix { - if ver == compat.TorchVersion() && compat.CUDA != nil { - cudas = append(cudas, *compat.CUDA) - } - } + return cudas, nil } diff --git a/pkg/config/compatibility_test.go b/pkg/config/compatibility_test.go index 3bd6656839..022c9ea312 100644 --- a/pkg/config/compatibility_test.go +++ b/pkg/config/compatibility_test.go @@ -12,62 +12,9 @@ func TestLatestCuDNNForCUDA(t *testing.T) { require.Equal(t, "8", actual) } -func TestGenerateTorchMinorVersionCompatibilityMatrix(t *testing.T) { - matrix := []TorchCompatibility{{ - Torch: "2.0.0", - CUDA: nil, - Pythons: []string{"3.7", "3.8"}, - }, { - Torch: "2.0.0", - CUDA: stringp("12.0"), - Pythons: []string{"3.7", "3.8"}, - }, { - Torch: "2.0.1", - CUDA: stringp("12.0"), - Pythons: []string{"3.7", "3.8", "3.9"}, - }, { - Torch: "2.0.2", - CUDA: stringp("12.0"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.1.0", - CUDA: stringp("12.2"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.1.1", - CUDA: stringp("12.3"), - Pythons: []string{"3.9", "3.10"}, - }} - actual := generateTorchMinorVersionCompatibilityMatrix(matrix) - - expected := []TorchCompatibility{{ - Torch: "2.1", - CUDA: stringp("12.3"), - Pythons: []string{"3.9", "3.10"}, - }, { - Torch: "2.1", - CUDA: stringp("12.2"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.0", - CUDA: stringp("12.0"), - Pythons: []string{"3.8", "3.9"}, - }, { - Torch: "2.0", - CUDA: nil, - Pythons: []string{"3.7", "3.8"}, - }} - - require.Equal(t, expected, actual) -} - func TestCudasFromTorchWithCUVersionModifier(t *testing.T) { cudas, err := cudasFromTorch("2.0.1+cu118") require.GreaterOrEqual(t, len(cudas), 1) require.Equal(t, cudas[0], "11.8") require.Nil(t, err) } - -func stringp(s string) *string { - return &s -} diff --git a/pkg/config/config.go b/pkg/config/config.go index 76af8a5a7a..c94f0a9a8a 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -181,6 +181,10 @@ func (c *Config) TorchvisionVersion() (string, bool) { return c.pythonPackageVersion("torchvision") } +func (c *Config) TorchaudioVersion() (string, bool) { + return c.pythonPackageVersion("torchaudio") +} + func (c *Config) TensorFlowVersion() (string, bool) { return c.pythonPackageVersion("tensorflow") } diff --git a/pkg/config/cuda_base_images.json b/pkg/config/cuda_base_images.json index e92d3dd0e7..6386e05f17 100644 --- a/pkg/config/cuda_base_images.json +++ b/pkg/config/cuda_base_images.json @@ -2,14 +2,14 @@ { "Tag": "12.4.1-cudnn-devel-ubuntu22.04", "CUDA": "12.4.1", - "CuDNN": "", + "CuDNN": "9", "IsDevel": true, "Ubuntu": "22.04" }, { "Tag": "12.4.1-cudnn-devel-ubuntu20.04", "CUDA": "12.4.1", - "CuDNN": "", + "CuDNN": "9", "IsDevel": true, "Ubuntu": "20.04" }, @@ -286,4 +286,4 @@ "IsDevel": true, "Ubuntu": "16.04" } -] \ No newline at end of file +] diff --git a/pkg/config/torch_compatibility_matrix.json b/pkg/config/torch_compatibility_matrix.json index 4b30a3319b..ae7a01d7b5 100644 --- a/pkg/config/torch_compatibility_matrix.json +++ b/pkg/config/torch_compatibility_matrix.json @@ -1,4 +1,64 @@ [ + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cu124", + "CUDA": "12.4", + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cu121", + "CUDA": "12.1", + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cu118", + "CUDA": "11.8", + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, + { + "Torch": "2.4.0", + "Torchvision": "0.19.0", + "Torchaudio": "2.4.0", + "FindLinks": "", + "ExtraIndexURL": "https://download.pytorch.org/whl/cpu", + "CUDA": null, + "Pythons": [ + "3.10", + "3.11", + "3.12", + "3.8", + "3.9" + ] + }, { "Torch": "2.3.1+cpu", "Torchvision": "0.18.1", @@ -1409,4 +1469,4 @@ "3.11" ] } -] \ No newline at end of file +] diff --git a/pkg/dockerfile/base.go b/pkg/dockerfile/base.go index 510c7a3767..d85a4c49c7 100644 --- a/pkg/dockerfile/base.go +++ b/pkg/dockerfile/base.go @@ -99,8 +99,6 @@ func (b BaseImageConfiguration) MarshalJSON() ([]byte, error) { } // BaseImageConfigurations returns a list of CUDA/Python/Torch versions -// with patch versions stripped out. Each version is greater or equal to -// MinimumCUDAVersion/MinimumPythonVersion/MinimumTorchVersion. func BaseImageConfigurations() []BaseImageConfiguration { configs := []BaseImageConfiguration{} @@ -110,17 +108,19 @@ func BaseImageConfigurations() []BaseImageConfiguration { cudaVersionsSet := make(map[string]bool) // Torch configs - for _, compat := range config.TorchMinorCompatibilityMatrix { + for _, compat := range config.TorchCompatibilityMatrix { for _, python := range compat.Pythons { - // Only support fast cold boots for Torch with CUDA. - // Torch without CUDA is a rarely used edge case. if compat.CUDA == nil { + configs = append(configs, BaseImageConfiguration{ + PythonVersion: python, + TorchVersion: compat.Torch, + }) continue } cuda := *compat.CUDA - torch := version.StripPatch(compat.Torch) + torch := compat.Torch conf := BaseImageConfiguration{ CUDAVersion: cuda, PythonVersion: python, @@ -158,7 +158,8 @@ func BaseImageConfigurations() []BaseImageConfiguration { } func NewBaseImageGenerator(cudaVersion string, pythonVersion string, torchVersion string) (*BaseImageGenerator, error) { - if BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion) { + valid, cudaVersion, pythonVersion, torchVersion := BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion) + if valid { return &BaseImageGenerator{cudaVersion, pythonVersion, torchVersion}, nil } printNone := func(s string) string { @@ -180,6 +181,8 @@ func (g *BaseImageGenerator) GenerateDockerfile() (string, error) { if err != nil { return "", err } + useCogBaseImage := false + generator.useCogBaseImage = &useCogBaseImage dockerfile, err := generator.generateInitialSteps() if err != nil { @@ -208,7 +211,35 @@ func (g *BaseImageGenerator) makeConfig() (*config.Config, error) { func (g *BaseImageGenerator) pythonPackages() []string { if g.torchVersion != "" { - return []string{"torch==" + g.torchVersion} + pkgs := []string{"torch==" + g.torchVersion} + + // Find torchvision compatibility. + for _, compat := range config.TorchCompatibilityMatrix { + if len(compat.Torchvision) == 0 { + continue + } + if !version.Matches(g.torchVersion, compat.TorchVersion()) { + continue + } + + pkgs = append(pkgs, "torchvision=="+compat.Torchvision) + break + } + + // Find torchaudio compatibility. + for _, compat := range config.TorchCompatibilityMatrix { + if len(compat.Torchaudio) == 0 { + continue + } + if !version.Matches(g.torchVersion, compat.TorchVersion()) { + continue + } + + pkgs = append(pkgs, "torchaudio=="+compat.Torchaudio) + break + } + + return pkgs } return []string{} } @@ -218,6 +249,8 @@ func (g *BaseImageGenerator) runStatements() []config.RunItem { } func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string) string { + _, cudaVersion, pythonVersion, torchVersion = BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion) + components := []string{} if cudaVersion != "" { components = append(components, "cuda"+version.StripPatch(cudaVersion)) @@ -226,7 +259,7 @@ func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string components = append(components, "python"+version.StripPatch(pythonVersion)) } if torchVersion != "" { - components = append(components, "torch"+version.StripPatch(torchVersion)) + components = append(components, "torch"+version.StripModifier(torchVersion)) } tag := strings.Join(components, "-") @@ -237,7 +270,8 @@ func BaseImageName(cudaVersion string, pythonVersion string, torchVersion string return BaseImageRegistry + "/cog-base:" + tag } -func BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion string) bool { +func BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion string) (bool, string, string, string) { + compatibleTorchVersion := "" for _, conf := range BaseImageConfigurations() { // Check CUDA version compatibility if !isVersionCompatible(conf.CUDAVersion, cudaVersion) { @@ -254,14 +288,22 @@ func BaseImageConfigurationExists(cudaVersion, pythonVersion, torchVersion strin continue } - return true + if compatibleTorchVersion == "" || version.Greater(conf.TorchVersion, compatibleTorchVersion) { + compatibleTorchVersion = version.StripModifier(conf.TorchVersion) + } } - return false + + valid := (torchVersion != "" && compatibleTorchVersion != "") || torchVersion == "" + if valid { + torchVersion = compatibleTorchVersion + } + + return valid, cudaVersion, pythonVersion, torchVersion } func isVersionCompatible(confVersion, requestedVersion string) bool { if confVersion == "" || requestedVersion == "" { return confVersion == requestedVersion } - return version.Matches(confVersion, requestedVersion) + return version.Matches(requestedVersion, confVersion) } diff --git a/pkg/dockerfile/base_test.go b/pkg/dockerfile/base_test.go index 25dd38bd23..3b2415b9fa 100644 --- a/pkg/dockerfile/base_test.go +++ b/pkg/dockerfile/base_test.go @@ -1,6 +1,8 @@ package dockerfile import ( + "reflect" + "strings" "testing" "github.com/stretchr/testify/require" @@ -16,15 +18,61 @@ func TestBaseImageName(t *testing.T) { {"", "3.8", "", "r8.im/cog-base:python3.8"}, {"", "3.8", "2.1", - "r8.im/cog-base:python3.8-torch2.1"}, + "r8.im/cog-base:python3.8-torch2.1.2"}, {"12.1", "3.8", "", "r8.im/cog-base:cuda12.1-python3.8"}, {"12.1", "3.8", "2.1", - "r8.im/cog-base:cuda12.1-python3.8-torch2.1"}, + "r8.im/cog-base:cuda12.1-python3.8-torch2.1.2"}, {"12.1", "3.8", "2.1", - "r8.im/cog-base:cuda12.1-python3.8-torch2.1"}, + "r8.im/cog-base:cuda12.1-python3.8-torch2.1.2"}, } { actual := BaseImageName(tt.cuda, tt.python, tt.torch) require.Equal(t, tt.expected, actual) } } + +func TestBaseImageNameWithVersionModifier(t *testing.T) { + actual := BaseImageName("12.1", "3.8", "2.0.1+cu118") + require.Equal(t, "r8.im/cog-base:cuda12.1-python3.8-torch2.0.1", actual) +} + +func TestGenerateDockerfile(t *testing.T) { + generator, err := NewBaseImageGenerator( + "12.1", + "3.8", + "2.1.0", + ) + require.NoError(t, err) + dockerfile, err := generator.GenerateDockerfile() + require.NoError(t, err) + require.True(t, strings.Contains(dockerfile, "FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04")) +} + +func TestBaseImageConfigurationExists(t *testing.T) { + exists, _, _, torchVersion := BaseImageConfigurationExists("12.1", "3.9", "2.3") + require.True(t, exists) + require.Equal(t, "2.3.1", torchVersion) +} + +func TestBaseImageConfigurationExistsNoTorch(t *testing.T) { + exists, _, _, _ := BaseImageConfigurationExists("", "3.12", "") + require.True(t, exists) +} + +func TestBaseImageConfigurationExistsNoCUDA(t *testing.T) { + exists, _, _, torchVersion := BaseImageConfigurationExists("", "3.8", "2.1") + require.True(t, exists) + require.Equal(t, "2.1.2", torchVersion) +} + +func TestIsVersionCompatible(t *testing.T) { + compatible := isVersionCompatible("2.3.1+cu121", "2.3") + require.True(t, compatible) +} + +func TestPythonPackages(t *testing.T) { + generator, err := NewBaseImageGenerator("12.1", "3.9", "2.1.0") + require.NoError(t, err) + pkgs := generator.pythonPackages() + require.Truef(t, reflect.DeepEqual(pkgs, []string{"torch==2.1.0", "torchvision==0.16.0", "torchaudio==2.1.0"}), "expected %v", pkgs) +} diff --git a/pkg/dockerfile/generator.go b/pkg/dockerfile/generator.go index ee17620fa4..ec5e9cf52b 100644 --- a/pkg/dockerfile/generator.go +++ b/pkg/dockerfile/generator.go @@ -401,6 +401,9 @@ func (g *Generator) pipInstalls() (string, error) { if torchvisionVersion, ok := g.Config.TorchvisionVersion(); ok { excludePackages = append(excludePackages, "torchvision=="+torchvisionVersion) } + if torchaudioVersion, ok := g.Config.TorchaudioVersion(); ok { + excludePackages = append(excludePackages, "torchaudio=="+torchaudioVersion) + } g.pythonRequirementsContents, err = g.Config.PythonRequirementsForArch(g.GOOS, g.GOARCH, excludePackages) if err != nil { return "", err @@ -594,13 +597,6 @@ func (g *Generator) determineBaseImageName() (string, error) { } torchVersion, _ := g.Config.TorchVersion() - torchVersion, changed, err = stripPatchVersion(torchVersion) - if err != nil { - return "", err - } - if changed { - console.Warnf("Stripping patch version from Torch version %s to %s", g.Config.Build.PythonVersion, pythonVersion) - } // validate that the base image configuration exists imageGenerator, err := NewBaseImageGenerator(cudaVersion, pythonVersion, torchVersion) diff --git a/pkg/dockerfile/generator_test.go b/pkg/dockerfile/generator_test.go index 77e73e8ff5..832be0886e 100644 --- a/pkg/dockerfile/generator_test.go +++ b/pkg/dockerfile/generator_test.go @@ -560,6 +560,7 @@ func TestGenerateFullGPUWithCogBaseImage(t *testing.T) { build: gpu: true cuda: "11.8" + python_version: "3.11" system_packages: - ffmpeg - cowsay @@ -580,17 +581,22 @@ predict: predict.py:Predictor _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") require.NoError(t, err) - expected := `#syntax=docker/dockerfile:1.4 + // We add the patch version to the expected torch version + expectedTorchVersion := torchVersion + if torchVersion == "2.3" { + expectedTorchVersion = "2.3.1" + } + expected := fmt.Sprintf(`#syntax=docker/dockerfile:1.4 FROM r8.im/replicate/cog-test-weights AS weights -FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3 +FROM r8.im/cog-base:cuda11.8-python3.11-torch%s RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* -COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt +COPY `+gen.relativeTmpDir+`/requirements.txt /tmp/requirements.txt RUN pip install -r /tmp/requirements.txt RUN cowsay moo WORKDIR /src EXPOSE 5000 CMD ["python", "-m", "cog.server.http"] -COPY . /src` +COPY . /src`, expectedTorchVersion) require.Equal(t, expected, actual) @@ -599,3 +605,49 @@ COPY . /src` require.Equal(t, "pandas==2.0.3", string(requirements)) } } + +func TestGenerateTorchWithStrippedModifiedVersion(t *testing.T) { + tmpDir := t.TempDir() + + yaml := ` +build: + gpu: true + cuda: "11.8" + system_packages: + - ffmpeg + - cowsay + python_packages: + - torch==2.3.1+cu118 + - pandas==2.0.3 + run: + - "cowsay moo" +predict: predict.py:Predictor +` + conf, err := config.FromYAML([]byte(yaml)) + require.NoError(t, err) + require.NoError(t, conf.ValidateAndComplete("")) + + gen, err := NewGenerator(conf, tmpDir) + require.NoError(t, err) + gen.SetUseCogBaseImage(true) + _, actual, _, err := gen.GenerateModelBaseWithSeparateWeights("r8.im/replicate/cog-test") + require.NoError(t, err) + + expected := `#syntax=docker/dockerfile:1.4 +FROM r8.im/replicate/cog-test-weights AS weights +FROM r8.im/cog-base:cuda11.8-python3.12-torch2.3.1 +RUN --mount=type=cache,target=/var/cache/apt,sharing=locked apt-get update -qq && apt-get install -qqy cowsay && rm -rf /var/lib/apt/lists/* +COPY ` + gen.relativeTmpDir + `/requirements.txt /tmp/requirements.txt +RUN pip install -r /tmp/requirements.txt +RUN cowsay moo +WORKDIR /src +EXPOSE 5000 +CMD ["python", "-m", "cog.server.http"] +COPY . /src` + + require.Equal(t, expected, actual) + + requirements, err := os.ReadFile(path.Join(gen.tmpDir, "requirements.txt")) + require.NoError(t, err) + require.Equal(t, "pandas==2.0.3", string(requirements)) +} diff --git a/pkg/util/version/version.go b/pkg/util/version/version.go index 7c3c6afeab..ca62150c6e 100644 --- a/pkg/util/version/version.go +++ b/pkg/util/version/version.go @@ -9,7 +9,7 @@ import ( type Version struct { Major int Minor int - Patch int + Patch *int Metadata string } @@ -32,7 +32,9 @@ func NewVersion(s string) (version *Version, err error) { } } if len(parts) >= 3 { - version.Patch, err = strconv.Atoi(parts[2]) + patch, err := strconv.Atoi(parts[2]) + version.Patch = new(int) + *version.Patch = patch if err != nil { return nil, fmt.Errorf("Invalid patch version %s: %w", parts[2], err) } @@ -59,7 +61,9 @@ func (v *Version) Greater(other *Version) bool { return true case v.Major == other.Major && v.Minor > other.Minor: return true - case v.Major == other.Major && v.Minor == other.Minor && v.Patch > other.Patch: + case v.Major == other.Major && + v.Minor == other.Minor && + v.PatchVersion() > other.PatchVersion(): return true default: return false @@ -67,7 +71,10 @@ func (v *Version) Greater(other *Version) bool { } func (v *Version) Equal(other *Version) bool { - return v.Major == other.Major && v.Minor == other.Minor && v.Patch == other.Patch && v.Metadata == other.Metadata + return v.Major == other.Major && + v.Minor == other.Minor && + v.PatchVersion() == other.PatchVersion() && + v.Metadata == other.Metadata } func (v *Version) GreaterOrEqual(other *Version) bool { @@ -78,6 +85,17 @@ func (v *Version) EqualMinor(other *Version) bool { return v.Major == other.Major && v.Minor == other.Minor } +func (v *Version) HasPatch() bool { + return v.Patch != nil +} + +func (v *Version) PatchVersion() int { + if v.Patch == nil { + return 0 + } + return *v.Patch +} + func Equal(v1 string, v2 string) bool { return MustVersion(v1).Equal(MustVersion(v2)) } @@ -100,7 +118,7 @@ func (v *Version) Matches(other *Version) bool { return false case v.Minor != other.Minor: return false - case v.Patch != 0 && v.Patch != other.Patch: + case v.HasPatch() && other.HasPatch() && *v.Patch != *other.Patch: return false default: return true @@ -115,3 +133,8 @@ func StripPatch(v string) string { ver := MustVersion(v) return fmt.Sprintf("%d.%d", ver.Major, ver.Minor) } + +func StripModifier(v string) string { + modifierSplit := strings.Split(v, "+") + return modifierSplit[0] +} diff --git a/pkg/util/version/version_test.go b/pkg/util/version/version_test.go index af83caafac..e3d7f96e51 100644 --- a/pkg/util/version/version_test.go +++ b/pkg/util/version/version_test.go @@ -60,3 +60,22 @@ func TestVersionGreater(t *testing.T) { require.Equal(t, tt.greater, Greater(tt.v1, tt.v2), "%s is %sgreater than %s", tt.v1, not, tt.v2) } } + +func TestVersionStripModifier(t *testing.T) { + version := "2.3.1" + versionWithModifier := version + "+cu118" + versionWithoutModifier := StripModifier(versionWithModifier) + require.Equal(t, versionWithoutModifier, version) +} + +func TestVersionMatches(t *testing.T) { + version := "2.3" + matchVersion := "2.3.2" + require.True(t, Matches(version, matchVersion)) +} + +func TestVersionMatchesModifier(t *testing.T) { + version := "2.3" + matchVersion := "2.3.2+cu118" + require.True(t, Matches(version, matchVersion)) +}