Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add selfupdate #51

Merged
merged 9 commits into from
Jan 5, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04]
go: ["1.22.7", "1.23.1"]
go: ["1.22.7", "1.23.3"]
goos: [linux]
goarch: [amd64, arm64]

Expand Down
23 changes: 12 additions & 11 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,24 @@ func init() {
rootCommand.AddCommand(versionCommand)
rootCommand.AddCommand(configureCmd)
rootCommand.AddCommand(newDiagnosticsCommand())
rootCommand.AddCommand(newSelfupdateCommand())
}

func isDockerSnap() bool {
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
log.Fatalf("Unable to initialize Docker client: %s", err)
}
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
log.Fatalf("Unable to initialize Docker client: %s", err)
}

defer cli.Close() // Close the client when the function returns (should not be needed, but just to be safe)
defer cli.Close() // Close the client when the function returns (should not be needed, but just to be safe)

info, err := cli.Info(context.Background())
if err != nil {
log.Fatalf("Unable to get Docker info: %s", err)
}
info, err := cli.Info(context.Background())
if err != nil {
log.Fatalf("Unable to get Docker info: %s", err)
}

// Check if Docker root directory contains '/var/snap/docker'
return strings.Contains(info.DockerRootDir, "/var/snap/docker")
// Check if Docker root directory contains '/var/snap/docker'
return strings.Contains(info.DockerRootDir, "/var/snap/docker")
}

func rootCmdRun(cmd *cobra.Command, _ []string) {
Expand Down
257 changes: 257 additions & 0 deletions cmd/selfupdate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package cmd

import (
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"os/exec"
"runtime"
"strings"
"time"

"github.com/pelican-dev/wings/system"
"github.com/spf13/cobra"
)

var updateArgs struct {
repoOwner string
repoName string
}

func newSelfupdateCommand() *cobra.Command {
command := &cobra.Command{
Use: "update",
Short: "Update the wings to the latest version",
Run: selfupdateCmdRun,
}

command.Flags().StringVar(&updateArgs.repoOwner, "repo-owner", "pelican-dev", "GitHub username or organization that owns the repository containing the updates")
command.Flags().StringVar(&updateArgs.repoName, "repo-name", "wings", "The name of the GitHub repository to fetch updates from")

return command
}

func selfupdateCmdRun(*cobra.Command, []string) {
currentVersion := system.Version
if currentVersion == "" {
fmt.Println("Error: Current version is not defined")
return
}

if currentVersion == "develop" {
fmt.Println("Running in development mode. Skipping update.")
return
}

fmt.Println("Current version:", currentVersion)

// Fetch the latest release tag from GitHub API
latestVersionTag, err := fetchLatestGitHubRelease()
if err != nil {
fmt.Println("Failed to fetch the latest version:", err)
return
}

currentVersionTag := "v" + currentVersion
if latestVersionTag == currentVersionTag {
fmt.Println("You are running the latest version:", currentVersion)
return
}

fmt.Printf("A new version is available: %s (current: %s)\n", latestVersionTag, currentVersionTag)

binaryName := determineBinaryName()
if binaryName == "" {
fmt.Println("Unsupported architecture")
return
}

downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", updateArgs.repoOwner, updateArgs.repoName, latestVersionTag, binaryName)
checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/checksums.txt", updateArgs.repoOwner, updateArgs.repoName, latestVersionTag)

fmt.Println("Downloading checksums.txt...")
checksumFile, err := downloadFile(checksumURL, "checksums.txt")
if err != nil {
fmt.Println("Failed to download checksum file:", err)
return
}
defer os.Remove(checksumFile)

fmt.Println("Downloading", binaryName, "...")
binaryFile, err := downloadFile(downloadURL, binaryName)
if err != nil {
fmt.Println("Failed to download binary file:", err)
return
}
defer os.Remove(binaryFile)

if err := verifyChecksum(binaryFile, checksumFile, binaryName); err != nil {
fmt.Println("Checksum verification failed:", err)
return
}
fmt.Println("\nChecksum verification successful.")

currentExecutable, err := os.Executable()
if err != nil {
fmt.Println("Failed to locate current executable:", err)
return
}

if err := os.Chmod(binaryFile, 0755); err != nil {
fmt.Println("Failed to set executable permissions on the new binary:", err)
return
}

if err := replaceBinary(currentExecutable, binaryFile); err != nil {
fmt.Println("Failed to replace executable:", err)
return
}

fmt.Println("Restarting service...")

if err := restartService(); err != nil {
fmt.Println("Error restarting the wings service:", err)
} else {
fmt.Println("Service restarted successfully.")
}
}

func fetchLatestGitHubRelease() (string, error) {
apiURL := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", updateArgs.repoOwner, updateArgs.repoName)

resp, err := http.Get(apiURL)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

var releaseData struct {
TagName string `json:"tag_name"`
}
if err := json.NewDecoder(resp.Body).Decode(&releaseData); err != nil {
return "", err
}

return releaseData.TagName, nil
}

func determineBinaryName() string {
switch runtime.GOARCH {
case "amd64":
return "wings_linux_amd64"
case "arm64":
return "wings_linux_arm64"
default:
return ""
}
}

func downloadFile(url, fileName string) (string, error) {
tmpFile, err := os.CreateTemp("", fileName)
if err != nil {
return "", err
}
defer tmpFile.Close()

resp, err := http.Get(url)
if err != nil {
return "", err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("unexpected status: %s", resp.Status)
}

fmt.Printf("Downloading %s (%.2f MB)...\n", fileName, float64(resp.ContentLength)/1024/1024)
progressWriter := &progressWriter{Writer: tmpFile, Total: resp.ContentLength}
if _, err := io.Copy(progressWriter, resp.Body); err != nil {
return "", err
}

fmt.Println() // Ensure a newline after download progress
return tmpFile.Name(), nil
}

func verifyChecksum(binaryPath, checksumPath, binaryName string) error {
checksumData, err := os.ReadFile(checksumPath)
if err != nil {
return err
}

var expectedChecksum string
for _, line := range strings.Split(string(checksumData), "\n") {
if strings.HasSuffix(line, binaryName) {
parts := strings.Fields(line)
if len(parts) > 0 {
expectedChecksum = parts[0]
}
break
}
}
if expectedChecksum == "" {
return fmt.Errorf("checksum not found for %s", binaryName)
}

file, err := os.Open(binaryPath)
if err != nil {
return err
}
defer file.Close()

hasher := sha256.New()
if _, err := io.Copy(hasher, file); err != nil {
return err
}
actualChecksum := fmt.Sprintf("%x", hasher.Sum(nil))

if actualChecksum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksum)
}

return nil
}

func replaceBinary(currentPath, newPath string) error {
return os.Rename(newPath, currentPath)
}

type progressWriter struct {
io.Writer
Total int64
Written int64
StartTime time.Time
}

func (pw *progressWriter) Write(p []byte) (int, error) {
n, err := pw.Writer.Write(p)
pw.Written += int64(n)

if pw.Total > 0 {
percent := float64(pw.Written) / float64(pw.Total) * 100
fmt.Printf("\rProgress: %.2f%%", percent)
}

return n, err
}

func restartService() error {
// Try to run the systemctl restart command
cmd := exec.Command("systemctl", "restart", "wings")
QuintenQVD0 marked this conversation as resolved.
Show resolved Hide resolved
cmdOutput, err := cmd.CombinedOutput()

if err != nil {
// If systemctl command fails, return the error with output
return fmt.Errorf("failed to restart service: %s\n%s", err.Error(), string(cmdOutput))
}

// If successful, return nil
return nil
}
Loading