Skip to content

Commit

Permalink
add selfupdate (#51)
Browse files Browse the repository at this point in the history
* add selfupdate

* Bump golang 1.23 sub version

* Remove systemd restart from autoupdate

* we don't need os exec anymore

* cleanup

* Simplify version logic

* Update: Checksum verification successful message
  • Loading branch information
QuintenQVD0 authored Jan 5, 2025
1 parent cb6f528 commit ad292b6
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04]
go: ["1.22.7", "1.23.1"]
go: ["1.22.7", "1.23.3"]
goos: [linux]
goarch: [amd64, arm64]

Expand Down
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func init() {
rootCommand.AddCommand(versionCommand)
rootCommand.AddCommand(configureCmd)
rootCommand.AddCommand(newDiagnosticsCommand())
rootCommand.AddCommand(newSelfupdateCommand())
}

func isDockerSnap() bool {
Expand Down
249 changes: 249 additions & 0 deletions cmd/selfupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
package cmd

import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"time"

"github.com/pelican-dev/wings/system"
"github.com/spf13/cobra"
)

var updateArgs struct {
repoOwner string
repoName string
force bool
}

func newSelfupdateCommand() *cobra.Command {
command := &cobra.Command{
Use: "update",
Short: "Update wings to the latest version",
Run: selfupdateCmdRun,
}

command.Flags().StringVar(&updateArgs.repoOwner, "repo-owner", "pelican-dev", "GitHub repository owner")
command.Flags().StringVar(&updateArgs.repoName, "repo-name", "wings", "GitHub repository name")
command.Flags().BoolVar(&updateArgs.force, "force", false, "Force update even if on latest version")

return command
}

func selfupdateCmdRun(_ *cobra.Command, _ []string) {
currentVersion := system.Version
if currentVersion == "" {
fmt.Println("Error: current version is not defined")
return
}

if currentVersion == "develop" && !updateArgs.force {
fmt.Println("Running in development mode. Use --force to override.")
return
}

fmt.Printf("Current version: %s\n", currentVersion)

// Fetch the latest release tag from GitHub API
latestVersionTag, err := fetchLatestGitHubRelease()
if err != nil {
fmt.Printf("Failed to fetch latest version: %v\n", err)
return
}

currentVersionTag := "v" + currentVersion
if currentVersion == "develop" {
currentVersionTag = currentVersion
}

if latestVersionTag == currentVersionTag && !updateArgs.force {
fmt.Printf("You are running the latest version: %s\n", currentVersion)
return
}

binaryName := determineBinaryName()
if binaryName == "" {
fmt.Printf("Error: unsupported architecture: %s\n", runtime.GOARCH)
return
}

fmt.Printf("Updating from %s to %s\n", currentVersionTag, latestVersionTag)

if err := performUpdate(latestVersionTag, binaryName); err != nil {
fmt.Printf("Update failed: %v\n", err)
return
}

fmt.Println("\nUpdate successful! Please restart the wings service (e.g., systemctl restart wings)")
}

func performUpdate(version, binaryName string) error {
downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
updateArgs.repoOwner, updateArgs.repoName, version, binaryName)
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/checksums.txt",
updateArgs.repoOwner, updateArgs.repoName, version)

tmpDir, err := os.MkdirTemp("", "wings-update-*")
if err != nil {
return fmt.Errorf("failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)

checksumPath := filepath.Join(tmpDir, "checksums.txt")
if err := downloadWithProgress(checksumURL, checksumPath); err != nil {
return fmt.Errorf("failed to download checksums: %v", err)
}

binaryPath := filepath.Join(tmpDir, binaryName)
if err := downloadWithProgress(downloadURL, binaryPath); err != nil {
return fmt.Errorf("failed to download binary: %v", err)
}

if err := verifyChecksum(binaryPath, checksumPath, binaryName); err != nil {
return fmt.Errorf("checksum verification failed: %v", err)
}

if err := os.Chmod(binaryPath, 0755); err != nil {
return fmt.Errorf("failed to set executable permissions: %v", err)
}

currentExecutable, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to locate current executable: %v", err)
}

return os.Rename(binaryPath, currentExecutable)
}

func downloadWithProgress(url, dest string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status: %s", resp.Status)
}

out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()

filename := filepath.Base(dest)
fmt.Printf("Downloading %s (%.2f MB)...\n", filename, float64(resp.ContentLength)/1024/1024)

pw := &progressWriter{
Writer: out,
Total: resp.ContentLength,
StartTime: time.Now(),
}

_, err = io.Copy(pw, resp.Body)
fmt.Println()
return err
}

func fetchLatestGitHubRelease() (string, error) {
apiURL := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", updateArgs.repoOwner, updateArgs.repoName)

resp, err := http.Get(apiURL)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

var releaseData struct {
TagName string `json:"tag_name"`
}
if err := json.NewDecoder(resp.Body).Decode(&releaseData); err != nil {
return "", err
}

return releaseData.TagName, nil
}

func determineBinaryName() string {
switch runtime.GOARCH {
case "amd64":
return "wings_linux_amd64"
case "arm64":
return "wings_linux_arm64"
default:
return ""
}
}

func verifyChecksum(binaryPath, checksumPath, binaryName string) error {
checksumData, err := os.ReadFile(checksumPath)
if err != nil {
return err
}

var expectedChecksum string
for _, line := range strings.Split(string(checksumData), "\n") {
if strings.HasSuffix(line, binaryName) {
parts := strings.Fields(line)
if len(parts) > 0 {
expectedChecksum = parts[0]
}
break
}
}
if expectedChecksum == "" {
return fmt.Errorf("checksum not found for %s", binaryName)
}

file, err := os.Open(binaryPath)
if err != nil {
return err
}
defer file.Close()

hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return err
}
actualChecksum := fmt.Sprintf("%x", hasher.Sum(nil))

if actualChecksum == expectedChecksum {
fmt.Printf("Checksum verification successful!\n")
}

if actualChecksum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
}

return nil
}

type progressWriter struct {
io.Writer
Total int64
Written int64
StartTime time.Time
}

func (pw *progressWriter) Write(p []byte) (int, error) {
n, err := pw.Writer.Write(p)
pw.Written += int64(n)

if pw.Total > 0 {
percent := float64(pw.Written) / float64(pw.Total) * 100
fmt.Printf("\rProgress: %.2f%%", percent)
}

return n, err
}

0 comments on commit ad292b6

Please sign in to comment.