diff --git a/pkg/handler/update.go b/pkg/handler/update.go index 9e2491b..1a4ab2f 100644 --- a/pkg/handler/update.go +++ b/pkg/handler/update.go @@ -12,8 +12,6 @@ import ( "path/filepath" "regexp" "runtime" - "strconv" - "strings" "time" ) @@ -40,31 +38,6 @@ func NewUpdateHandler() *UpdateHandler { } } -// compareVersion compares strings like "v1.1.1", returns true when versionA > versionB -func (u *UpdateHandler) compareVersion(versionA, versionB string) bool { - if versionA == "" { - return false - } - if versionB == "" { - return true - } - aList := strings.Split(versionA[1:], ".") - bList := strings.Split(versionB[1:], ".") - aLen := len(aList) - bLen := len(bList) - if aLen != bLen { - return aLen > bLen - } - for i := 0; i < aLen; i++ { - a, _ := strconv.Atoi(aList[i]) - b, _ := strconv.Atoi(bList[i]) - if a > b { - return true - } - } - return false -} - // CheckLatestVersion returns true when there is a newer version func (u *UpdateHandler) CheckLatestVersion() (bool, error) { resp, err := u.client.Get(fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", ipgw.Repo)) @@ -73,7 +46,7 @@ func (u *UpdateHandler) CheckLatestVersion() (bool, error) { } body := utils.ReadBody(resp) latestVersion, _ := utils.MatchSingle(regexp.MustCompile(`"tag_name": *"(.+?)"`), body) - return u.compareVersion(latestVersion, ipgw.Version), nil + return utils.CompareVersion(utils.ParseVersion(latestVersion), utils.ParseVersion(ipgw.Version)), nil } // download returns downloaded path diff --git a/pkg/utils/semver.go b/pkg/utils/semver.go new file mode 100644 index 0000000..c554ce9 --- /dev/null +++ b/pkg/utils/semver.go @@ -0,0 +1,84 @@ +package utils + +import ( + "strconv" + "strings" +) + +type Semver struct { + Major int + Minor int + Patch int + Prerelease string +} + +// CompareVersion compares strings like "v1.1.1", returns true when versionA > versionB +func CompareVersion(versionA, versionB *Semver) bool { + if versionA == nil { + return false + } + if versionB == nil { + return true + } + // semver: + // v1.0.1 > v1.0.0 + // v1.0.0 > v1.0.0-beta + // v1.0.0-beta.1 > v1.0.0-beta + // v1.0.0-beta > v1.0.0-alpha + if versionA.Major != versionB.Major { + return versionA.Major > versionB.Major + } + + if versionA.Minor != versionB.Minor { + return versionA.Minor > versionB.Minor + } + + if versionA.Patch != versionB.Patch { + return versionA.Patch > versionB.Patch + } + + if versionA.Prerelease != "" && versionB.Prerelease != "" { + return versionA.Prerelease > versionB.Prerelease + } + + if versionA.Prerelease != "" { + // versionB.Prerelease == "" + return false + } + + if versionB.Prerelease != "" { + // versionA.Prerelease == "" + return true + } + + return false +} + +func ParseVersion(version string) *Semver { + if version == "" { + return nil + } + // remove leading 'v' + version = strings.TrimPrefix(version, "v") + parts := strings.Split(version, "-") + if len(parts) < 1 { + return nil + } + dot := strings.Split(parts[0], ".") + if len(dot) < 3 { + return nil + } + preRelease := "" + if len(parts) > 1 { + preRelease = parts[1] + } + major, _ := strconv.Atoi(dot[0]) + minor, _ := strconv.Atoi(dot[1]) + patch, _ := strconv.Atoi(dot[2]) + return &Semver{ + Major: major, + Minor: minor, + Patch: patch, + Prerelease: preRelease, + } +} diff --git a/pkg/utils/semver_test.go b/pkg/utils/semver_test.go new file mode 100644 index 0000000..159f902 --- /dev/null +++ b/pkg/utils/semver_test.go @@ -0,0 +1,57 @@ +package utils + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestParseVersion(t *testing.T) { + a := assert.New(t) + cases := []struct { + v string + parsed *Semver + }{ + {v: "v1.0.1", parsed: &Semver{ + Major: 1, + Minor: 0, + Patch: 1, + Prerelease: "", + }}, + {v: "v1.0.1-beta", parsed: &Semver{ + Major: 1, + Minor: 0, + Patch: 1, + Prerelease: "beta", + }}, + {v: "0.2.0-alpha.3", parsed: &Semver{ + Major: 0, + Minor: 2, + Patch: 0, + Prerelease: "alpha.3", + }}, + {v: "0.1", parsed: nil}, + } + for _, c := range cases { + a.Equal(ParseVersion(c.v), c.parsed) + } +} + +func TestCompareVersion(t *testing.T) { + a := assert.New(t) + cases := []struct { + a string + b string + expected bool + }{ + {a: "v1.0.0", b: "v0.2.2", expected: true}, + {a: "v1.0.0-alpha", b: "v1.2.2", expected: false}, + {a: "v1.0.0-alpha", b: "v1.0.0", expected: false}, + {a: "v1.0.0", b: "v1.0.0-beta.2", expected: true}, + {a: "v1.0.0-beta", b: "v1.0.0-alpha", expected: true}, + {a: "v1", b: "v1.0.0-alpha", expected: false}, + {a: "1.0.0-beta", b: "", expected: true}, + } + for _, c := range cases { + a.Equal(CompareVersion(ParseVersion(c.a), ParseVersion(c.b)), c.expected) + } +}