Skip to content

Commit

Permalink
feat(golang-rewrite): implement asdf set command (#1829)
Browse files Browse the repository at this point in the history
  • Loading branch information
Stratus3D authored Dec 31, 2024
1 parent b18a46f commit f68b29b
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 18 deletions.
24 changes: 24 additions & 0 deletions internal/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"strings"
"text/tabwriter"

"github.com/asdf-vm/asdf/internal/cli/set"
"github.com/asdf-vm/asdf/internal/completions"
"github.com/asdf-vm/asdf/internal/config"
"github.com/asdf-vm/asdf/internal/exec"
Expand Down Expand Up @@ -263,6 +264,29 @@ func Execute(version string) {
return reshimCommand(logger, args.Get(0), args.Get(1))
},
},
{
Name: "set",
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "home",
Aliases: []string{"u"},
Usage: "The version should be set in the current users home directory",
},
&cli.BoolFlag{
Name: "parent",
Aliases: []string{"p"},
Usage: "The version should be set in the closest existing .tool-versions file in a parent directory",
},
},
Action: func(cCtx *cli.Context) error {
args := cCtx.Args().Slice()
home := cCtx.Bool("home")
parent := cCtx.Bool("parent")
return set.Main(os.Stdout, os.Stderr, args, home, parent, func() (string, error) {
return os.UserHomeDir()
})
},
},
{
Name: "shimversions",
Action: func(cCtx *cli.Context) error {
Expand Down
105 changes: 105 additions & 0 deletions internal/cli/set/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Package set provides the 'asdf set' command
package set

import (
"errors"
"fmt"
"io"
"os"
"path/filepath"

"github.com/asdf-vm/asdf/internal/config"
"github.com/asdf-vm/asdf/internal/plugins"
"github.com/asdf-vm/asdf/internal/toolversions"
"github.com/asdf-vm/asdf/internal/versions"
)

// Main function is the entrypoint for the 'asdf set' command
func Main(_ io.Writer, stderr io.Writer, args []string, home bool, parent bool, homeFunc func() (string, error)) error {
if len(args) < 1 {
return printError(stderr, "tool and version must be provided as arguments")
}

if len(args) < 2 {
return printError(stderr, "version must be provided as an argument")
}

if home && parent {
return printError(stderr, "home and parent flags cannot both be specified; must be one location or the other")
}

conf, err := config.LoadConfig()
if err != nil {
return printError(stderr, fmt.Sprintf("error loading config: %s", err))
}

resolvedVersions := []string{}

for _, version := range args[1:] {
parsedVersion := toolversions.ParseFromCliArg(version)
if parsedVersion.Type == "latest" {
plugin := plugins.New(conf, args[0])
resolvedVersion, err := versions.Latest(plugin, parsedVersion.Value)
if err != nil {
return fmt.Errorf("unable to resolve latest version for %s", plugin.Name)
}
resolvedVersions = append(resolvedVersions, resolvedVersion)
continue
}
resolvedVersions = append(resolvedVersions, version)
}

tv := toolversions.ToolVersions{Name: args[0], Versions: resolvedVersions}

if home {
homeDir, err := homeFunc()
if err != nil {
return err
}

filepath := filepath.Join(homeDir, conf.DefaultToolVersionsFilename)
return toolversions.WriteToolVersionsToFile(filepath, []toolversions.ToolVersions{tv})
}

currentDir, err := os.Getwd()
if err != nil {
printError(stderr, fmt.Sprintf("unable to get current directory: %s", err))
return err
}

if parent {
// locate file in parent dir and update it
path, found := findVersionFileInParentDir(conf, currentDir)
if !found {
return printError(stderr, fmt.Sprintf("No %s version file found in parent directory", conf.DefaultToolVersionsFilename))
}

return toolversions.WriteToolVersionsToFile(path, []toolversions.ToolVersions{tv})
}

// Write new file in current dir
filepath := filepath.Join(currentDir, conf.DefaultToolVersionsFilename)
return toolversions.WriteToolVersionsToFile(filepath, []toolversions.ToolVersions{tv})
}

func printError(stderr io.Writer, msg string) error {
fmt.Fprintf(stderr, "%s", msg)
return errors.New(msg)
}

func findVersionFileInParentDir(conf config.Config, directory string) (string, bool) {
directory = filepath.Dir(directory)

for {
path := filepath.Join(directory, conf.DefaultToolVersionsFilename)
if _, err := os.Stat(path); err == nil {
return path, true
}

if directory == "/" {
return "", false
}

directory = filepath.Dir(directory)
}
}
105 changes: 105 additions & 0 deletions internal/cli/set/set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package set

import (
"os"
"path/filepath"
"strings"
"testing"

"github.com/stretchr/testify/assert"
)

func TestAll(t *testing.T) {
homeFunc := func() (string, error) {
return "", nil
}

t.Run("prints error when no arguments specified", func(t *testing.T) {
stdout, stderr := buildOutputs()
err := Main(&stdout, &stderr, []string{}, false, false, homeFunc)

assert.Error(t, err, "tool and version must be provided as arguments")
assert.Equal(t, stdout.String(), "")
assert.Equal(t, stderr.String(), "tool and version must be provided as arguments")
})

t.Run("prints error when no version specified", func(t *testing.T) {
stdout, stderr := buildOutputs()
err := Main(&stdout, &stderr, []string{"lua"}, false, false, homeFunc)

assert.Error(t, err, "version must be provided as an argument")
assert.Equal(t, stdout.String(), "")
assert.Equal(t, stderr.String(), "version must be provided as an argument")
})

t.Run("prints error when both --parent and --home flags are set", func(t *testing.T) {
stdout, stderr := buildOutputs()
err := Main(&stdout, &stderr, []string{"lua", "5.2.3"}, true, true, homeFunc)

assert.Error(t, err, "home and parent flags cannot both be specified; must be one location or the other")
assert.Equal(t, stdout.String(), "")
assert.Equal(t, stderr.String(), "home and parent flags cannot both be specified; must be one location or the other")
})

t.Run("sets version in current directory when no flags provided", func(t *testing.T) {
stdout, stderr := buildOutputs()
dir := t.TempDir()
assert.Nil(t, os.Chdir(dir))

err := Main(&stdout, &stderr, []string{"lua", "5.2.3"}, false, false, homeFunc)

assert.Nil(t, err)
assert.Equal(t, stdout.String(), "")
assert.Equal(t, stderr.String(), "")

path := filepath.Join(dir, ".tool-versions")
bytes, err := os.ReadFile(path)
assert.Nil(t, err)
assert.Equal(t, "lua 5.2.3\n", string(bytes))
})

t.Run("sets version in parent directory when --parent flag provided", func(t *testing.T) {
stdout, stderr := buildOutputs()
dir := t.TempDir()
subdir := filepath.Join(dir, "subdir")
assert.Nil(t, os.Mkdir(subdir, 0o777))
assert.Nil(t, os.Chdir(subdir))
assert.Nil(t, os.WriteFile(filepath.Join(dir, ".tool-versions"), []byte("lua 4.0.0"), 0o666))

err := Main(&stdout, &stderr, []string{"lua", "5.2.3"}, false, true, homeFunc)

assert.Nil(t, err)
assert.Equal(t, stdout.String(), "")
assert.Equal(t, stderr.String(), "")

path := filepath.Join(dir, ".tool-versions")
bytes, err := os.ReadFile(path)
assert.Nil(t, err)
assert.Equal(t, "lua 5.2.3\n", string(bytes))
})

t.Run("sets version in home directory when --home flag provided", func(t *testing.T) {
stdout, stderr := buildOutputs()
homedir := filepath.Join(t.TempDir(), "home")
assert.Nil(t, os.Mkdir(homedir, 0o777))
err := Main(&stdout, &stderr, []string{"lua", "5.2.3"}, true, false, func() (string, error) {
return homedir, nil
})

assert.Nil(t, err)
assert.Equal(t, stdout.String(), "")
assert.Equal(t, stderr.String(), "")

path := filepath.Join(homedir, ".tool-versions")
bytes, err := os.ReadFile(path)
assert.Nil(t, err)
assert.Equal(t, "lua 5.2.3\n", string(bytes))
})
}

func buildOutputs() (strings.Builder, strings.Builder) {
var stdout strings.Builder
var stderr strings.Builder

return stdout, stderr
}
5 changes: 3 additions & 2 deletions internal/help/help.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ asdf list <name> [version] List installed versions of a package and
optionally filter the versions
asdf list all <name> [<version>] List all versions of a package and
optionally filter the returned versions
asdf shell <name> <version> Set the package version to
`ASDF_${LANG}_VERSION` in the current shell
asdf set [-h] [-p] <name> <versions...> Set a tool version in a .tool-version in
the current directory, or a parent
directory.
asdf uninstall <name> <version> Remove a specific version of a package
asdf where <name> [<version>] Display install path for an installed
or current version
Expand Down
96 changes: 80 additions & 16 deletions internal/toolversions/toolversions.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,58 @@ type ToolVersions struct {
Versions []string
}

// WriteToolVersionsToFile takes a path to a file and writes the new tool and
// version data to the file. It creates the file if it does not exist and
// updates it if it does.
func WriteToolVersionsToFile(filepath string, toolVersions []ToolVersions) error {
content, err := os.ReadFile(filepath)
if _, ok := err.(*os.PathError); err != nil && !ok {
return err
}

updatedContent := updateContentWithToolVersions(string(content), toolVersions)
return os.WriteFile(filepath, []byte(updatedContent), 0o666)
}

func updateContentWithToolVersions(content string, toolVersions []ToolVersions) string {
var output strings.Builder

if content != "" {
for _, line := range readLines(content) {
tokens, comment := parseLine(line)
if len(tokens) > 1 {
tv := ToolVersions{Name: tokens[0], Versions: tokens[1:]}

indexMatching := slices.IndexFunc(toolVersions, func(toolVersion ToolVersions) bool {
return toolVersion.Name == tv.Name
})

if indexMatching != -1 {
// write updated version
newTv := toolVersions[indexMatching]
newTokens := toolVersionsToTokens(newTv)
fmt.Fprintf(&output, "%s\n", encodeLine(newTokens, comment))
toolVersions = slices.Delete(toolVersions, indexMatching, indexMatching+1)
continue
}
}

// write back original line
fmt.Fprintf(&output, "%s\n", line)
}
}

// If any ToolVersions structs remaining, write them to the end of the file
if len(toolVersions) > 0 {
for _, toolVersion := range toolVersions {
newTokens := toolVersionsToTokens(toolVersion)
fmt.Fprintf(&output, "%s\n", encodeLine(newTokens, ""))
}
}

return output.String()
}

// FindToolVersions looks up a tool version in a tool versions file and if found
// returns a slice of versions for it.
func FindToolVersions(filepath, toolName string) (versions []string, found bool, err error) {
Expand Down Expand Up @@ -153,17 +205,8 @@ func FormatForFS(version Version) string {
}
}

// readLines reads all the lines in a given file
// removing spaces and comments which are marked by '#'
func readLines(content string) (lines []string) {
for _, line := range strings.Split(content, "\n") {
line, _, _ = strings.Cut(line, "#")
line = strings.TrimSpace(line)
if len(line) > 0 {
lines = append(lines, line)
}
}
return
return strings.Split(content, "\n")
}

func findToolVersionsInContent(content, toolName string) (versions []string, found bool) {
Expand All @@ -179,21 +222,42 @@ func findToolVersionsInContent(content, toolName string) (versions []string, fou

func getAllToolsAndVersionsInContent(content string) (toolVersions []ToolVersions) {
for _, line := range readLines(content) {
tokens := parseLine(line)
newTool := ToolVersions{Name: tokens[0], Versions: tokens[1:]}
toolVersions = append(toolVersions, newTool)
tokens, _ := parseLine(line)
if len(tokens) > 1 {
newTool := ToolVersions{Name: tokens[0], Versions: tokens[1:]}
toolVersions = append(toolVersions, newTool)
}
}

return toolVersions
}

func parseLine(line string) (tokens []string) {
for _, token := range strings.Split(line, " ") {
// parseLine receives a single line from a file and parses it into a list of
// tokens and a comment. A comment may occur anywhere on the line and is started
// by a `#` character.
func parseLine(line string) (tokens []string, comment string) {
preComment, comment, _ := strings.Cut(line, "#")
for _, token := range strings.Split(preComment, " ") {
token = strings.TrimSpace(token)
if len(token) > 0 {
tokens = append(tokens, token)
}
}

return tokens
return tokens, comment
}

func toolVersionsToTokens(tv ToolVersions) []string {
return append([]string{tv.Name}, tv.Versions...)
}

func encodeLine(tokens []string, comment string) string {
tokensStr := strings.Join(tokens, " ")
if comment == "" {
if len(tokens) == 0 {
return ""
}
return tokensStr
}
return fmt.Sprintf("%s #%s", tokensStr, comment)
}
Loading

0 comments on commit f68b29b

Please sign in to comment.