Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add selfupdate #51

Merged
merged 9 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04]
go: ["1.22.7", "1.23.1"]
go: ["1.22.7", "1.23.3"]
goos: [linux]
goarch: [amd64, arm64]

Expand Down
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ func init() {
rootCommand.AddCommand(versionCommand)
rootCommand.AddCommand(configureCmd)
rootCommand.AddCommand(newDiagnosticsCommand())
rootCommand.AddCommand(newSelfupdateCommand())
}

func isDockerSnap() bool {
Expand Down
249 changes: 249 additions & 0 deletions cmd/selfupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
package cmd

import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"time"

"github.com/pelican-dev/wings/system"
"github.com/spf13/cobra"
)

var updateArgs struct {
repoOwner string
repoName string
force bool
}

func newSelfupdateCommand() *cobra.Command {
command := &cobra.Command{
Use: "update",
Short: "Update wings to the latest version",
Run: selfupdateCmdRun,
}

command.Flags().StringVar(&updateArgs.repoOwner, "repo-owner", "pelican-dev", "GitHub repository owner")
command.Flags().StringVar(&updateArgs.repoName, "repo-name", "wings", "GitHub repository name")
command.Flags().BoolVar(&updateArgs.force, "force", false, "Force update even if on latest version")

return command
}

func selfupdateCmdRun(_ *cobra.Command, _ []string) {
currentVersion := system.Version
if currentVersion == "" {
fmt.Println("Error: current version is not defined")
return
}

if currentVersion == "develop" && !updateArgs.force {
fmt.Println("Running in development mode. Use --force to override.")
return
}

fmt.Printf("Current version: %s\n", currentVersion)

// Fetch the latest release tag from GitHub API
latestVersionTag, err := fetchLatestGitHubRelease()
if err != nil {
fmt.Printf("Failed to fetch latest version: %v\n", err)
return
}

currentVersionTag := "v" + currentVersion
if currentVersion == "develop" {
currentVersionTag = currentVersion
}

if latestVersionTag == currentVersionTag && !updateArgs.force {
fmt.Printf("You are running the latest version: %s\n", currentVersion)
return
}

binaryName := determineBinaryName()
if binaryName == "" {
fmt.Printf("Error: unsupported architecture: %s\n", runtime.GOARCH)
return
}

fmt.Printf("Updating from %s to %s\n", currentVersionTag, latestVersionTag)

if err := performUpdate(latestVersionTag, binaryName); err != nil {
fmt.Printf("Update failed: %v\n", err)
return
}

fmt.Println("\nUpdate successful! Please restart the wings service (e.g., systemctl restart wings)")
}

func performUpdate(version, binaryName string) error {
downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s",
updateArgs.repoOwner, updateArgs.repoName, version, binaryName)
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/checksums.txt",
updateArgs.repoOwner, updateArgs.repoName, version)

tmpDir, err := os.MkdirTemp("", "wings-update-*")
if err != nil {
return fmt.Errorf("failed to create temp directory: %v", err)
}
defer os.RemoveAll(tmpDir)

checksumPath := filepath.Join(tmpDir, "checksums.txt")
if err := downloadWithProgress(checksumURL, checksumPath); err != nil {
return fmt.Errorf("failed to download checksums: %v", err)
}

binaryPath := filepath.Join(tmpDir, binaryName)
if err := downloadWithProgress(downloadURL, binaryPath); err != nil {
return fmt.Errorf("failed to download binary: %v", err)
}

if err := verifyChecksum(binaryPath, checksumPath, binaryName); err != nil {
return fmt.Errorf("checksum verification failed: %v", err)
}

if err := os.Chmod(binaryPath, 0755); err != nil {
return fmt.Errorf("failed to set executable permissions: %v", err)
}

currentExecutable, err := os.Executable()
if err != nil {
return fmt.Errorf("failed to locate current executable: %v", err)
}

return os.Rename(binaryPath, currentExecutable)
}

func downloadWithProgress(url, dest string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status: %s", resp.Status)
}

out, err := os.Create(dest)
if err != nil {
return err
}
defer out.Close()

filename := filepath.Base(dest)
fmt.Printf("Downloading %s (%.2f MB)...\n", filename, float64(resp.ContentLength)/1024/1024)

pw := &progressWriter{
Writer: out,
Total: resp.ContentLength,
StartTime: time.Now(),
}

_, err = io.Copy(pw, resp.Body)
fmt.Println()
return err
}

func fetchLatestGitHubRelease() (string, error) {
apiURL := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", updateArgs.repoOwner, updateArgs.repoName)

resp, err := http.Get(apiURL)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

var releaseData struct {
TagName string `json:"tag_name"`
}
if err := json.NewDecoder(resp.Body).Decode(&releaseData); err != nil {
return "", err
}

return releaseData.TagName, nil
}

func determineBinaryName() string {
switch runtime.GOARCH {
case "amd64":
return "wings_linux_amd64"
case "arm64":
return "wings_linux_arm64"
default:
return ""
}
}

func verifyChecksum(binaryPath, checksumPath, binaryName string) error {
checksumData, err := os.ReadFile(checksumPath)
if err != nil {
return err
}

var expectedChecksum string
for _, line := range strings.Split(string(checksumData), "\n") {
if strings.HasSuffix(line, binaryName) {
parts := strings.Fields(line)
if len(parts) > 0 {
expectedChecksum = parts[0]
}
break
}
}
if expectedChecksum == "" {
return fmt.Errorf("checksum not found for %s", binaryName)
}

file, err := os.Open(binaryPath)
if err != nil {
return err
}
defer file.Close()

hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return err
}
actualChecksum := fmt.Sprintf("%x", hasher.Sum(nil))

if actualChecksum == expectedChecksum {
fmt.Printf("Checksum verification successful!\n")
}

if actualChecksum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
}

return nil
}

type progressWriter struct {
io.Writer
Total int64
Written int64
StartTime time.Time
}

func (pw *progressWriter) Write(p []byte) (int, error) {
n, err := pw.Writer.Write(p)
pw.Written += int64(n)

if pw.Total > 0 {
percent := float64(pw.Written) / float64(pw.Total) * 100
fmt.Printf("\rProgress: %.2f%%", percent)
}

return n, err
}
Loading