Skip to content

Commit

Permalink
Optionally allow Stateful to be used (#581)
Browse files Browse the repository at this point in the history
Simple `--stateful` flag for now.
  • Loading branch information
sourishkrout authored May 22, 2024
1 parent c08cd96 commit 5a8ba35
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 46 deletions.
2 changes: 0 additions & 2 deletions internal/cmd/api_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/stateful/runme/v3/internal/auth"
"github.com/stateful/runme/v3/internal/client"
"github.com/stateful/runme/v3/internal/client/graphql/query"
"github.com/stateful/runme/v3/internal/extension"
"github.com/stateful/runme/v3/internal/log"
"github.com/stateful/runme/v3/internal/version"
"golang.org/x/oauth2"
Expand Down Expand Up @@ -89,7 +88,6 @@ var (
authEnv auth.Env // overwritten only in unit tests; when nil a default env will be used
authAuthorizer auth.Authorizer // overwritten only in unit tests
tokenStorage = &auth.DiskStorage{}
extensioner = extension.Default()
)

// authorizerWithEnv is a decorator that can return a token
Expand Down
2 changes: 1 addition & 1 deletion internal/cmd/code_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func codeServerCmd() *cobra.Command {
}
}

if _, err := runCodeServerCommand(cmd, execFile, false, "--install-extension", "stateful.runme", "--force"); err != nil {
if _, err := runCodeServerCommand(cmd, execFile, false, "--install-extension", fExtensionHandle, "--force"); err != nil {
return errors.Wrap(err, "failed to install extension to code-server")
}

Expand Down
42 changes: 22 additions & 20 deletions internal/cmd/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,35 +51,37 @@ type extUpdateMsg struct {
}

type extensionerModel struct {
force bool
loading bool
loadingMsg string
spinner spinner.Model
successMsg string
prompting bool
prompt prompt.QuestionModel
log *zap.Logger
force bool
extensioner extension.Extensioner
loading bool
loadingMsg string
spinner spinner.Model
successMsg string
prompting bool
prompt prompt.QuestionModel
log *zap.Logger
}

func newExtensionerModel(force bool) extensionerModel {
s := spinner.New()
s.Spinner = spinner.Line

return extensionerModel{
force: force,
prompt: prompt.NewQuestionModel("Do you want to install the extension?"),
spinner: s,
loading: true,
loadingMsg: "checking status of the extension...",
log: log.Get().Named("command.extensionerModel"),
force: force,
extensioner: extension.New(fStateful),
prompt: prompt.NewQuestionModel("Do you want to install the extension?"),
spinner: s,
loading: true,
loadingMsg: "checking status of the extension...",
log: log.Get().Named("command.extensionerModel"),
}
}

func (m extensionerModel) Init() tea.Cmd {
return tea.Batch(
m.prompt.Init(),
func() tea.Msg {
fullName, installed, err := extensioner.IsInstalled()
fullName, installed, err := m.extensioner.IsInstalled()
return extCheckMsg{
Installed: installed,
Name: fullName,
Expand All @@ -101,7 +103,7 @@ func (m extensionerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if msg.Value {
return m, func() tea.Msg { return prepExtInstallationMsg{} }
}
m.successMsg = fmt.Sprintf("You can install the extension manually using: %q", extension.InstallCommand())
m.successMsg = fmt.Sprintf("You can install the extension manually using: %q", m.extensioner.InstallCommand())
return m, tea.Quit

case prepExtInstallationMsg:
Expand All @@ -125,11 +127,11 @@ func (m extensionerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.loading = true
m.loadingMsg = "updating the extension..."
return m, func() tea.Msg {
if err := extensioner.Update(); err != nil {
if err := m.extensioner.Update(); err != nil {
return extUpdateMsg{Err: err}
}

updatedFullName, _, err := extensioner.IsInstalled()
updatedFullName, _, err := m.extensioner.IsInstalled()
if err != nil {
return extUpdateMsg{
Err: err,
Expand Down Expand Up @@ -176,10 +178,10 @@ func (m extensionerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
}

func (m extensionerModel) installExtension() tea.Msg {
if err := extensioner.Install(); err != nil {
if err := m.extensioner.Install(); err != nil {
return extUpdateMsg{Err: err}
}
installedFullName, _, err := extensioner.IsInstalled()
installedFullName, _, err := m.extensioner.IsInstalled()
if err != nil {
return extUpdateMsg{Err: err}
}
Expand Down
11 changes: 11 additions & 0 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/spf13/pflag"

"github.com/stateful/runme/v3/internal/cmd/beta"
"github.com/stateful/runme/v3/internal/extension"
)

var (
Expand All @@ -24,6 +25,8 @@ var (
fInsecure bool
fLogEnabled bool
fLogFilePath string
fExtensionHandle string
fStateful bool
)

func Root() *cobra.Command {
Expand Down Expand Up @@ -56,6 +59,12 @@ func Root() *cobra.Command {
if fFileMode && !cmd.Flags().Changed("allow-unnamed") {
fAllowUnnamed = true
}

if fExtensionHandle == "" && !fStateful {
fExtensionHandle = extension.DefaultExtensionName
} else {
fExtensionHandle = extension.PlatformExtensionName
}
},
ValidArgsFunction: func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
return nil, cobra.ShellCompDirectiveNoFileComp
Expand All @@ -80,6 +89,8 @@ func Root() *cobra.Command {
pflags.BoolVar(&fLogEnabled, "log", false, "Enable logging")
pflags.StringVar(&fLogFilePath, "log-file", filepath.Join(getTempDir(), "runme.log"), "Log file path")

pflags.BoolVar(&fStateful, "stateful", false, "Set Stateful instead of the Runme default")

setAPIFlags(pflags)

tuiCmd := tuiCmd()
Expand Down
64 changes: 44 additions & 20 deletions internal/extension/extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,54 @@ import (
"go.uber.org/zap"
)

const defaultName = "stateful.runme"

// Order matters. Default extension name should be first and legacies behind.
// It's the extension's job to make sure the newest version is used.
var allExtensionNames = []string{defaultName}
const (
DefaultExtensionName = "stateful.runme"
PlatformExtensionName = "stateful.platform"
)

//go:generate mockgen --build_flags=--mod=mod -destination=./extension_mock_gen.go -package=extension . Extensioner
type Extensioner interface {
IsInstalled() (string, bool, error)
Install() error
InstallCommand() string
Update() error
}

func Default() Extensioner {
return &extensioner{}
func New(isStateful bool) Extensioner {
if isStateful {
return newStateful()
}

return newDefault()
}

func newDefault() Extensioner {
return &extensioner{
extensionName: DefaultExtensionName,
allExtensionNames: []string{DefaultExtensionName},
}
}

func newStateful() Extensioner {
return &extensioner{
extensionName: PlatformExtensionName,
allExtensionNames: []string{PlatformExtensionName},
}
}

type extensioner struct{}
type extensioner struct {
extensionName string
allExtensionNames []string
}

func (extensioner) IsInstalled() (string, bool, error) { return IsInstalled() }
func (extensioner) Install() error { return Install() }
func (extensioner) Update() error { return Update() }
func (ext *extensioner) IsInstalled() (string, bool, error) {
return IsInstalled(ext.allExtensionNames)
}
func (ext *extensioner) Install() error { return Install(ext.extensionName) }
func (ext *extensioner) InstallCommand() string { return InstallCommand(ext.extensionName) }
func (ext *extensioner) Update() error { return Update(ext.extensionName) }

func IsInstalled() (string, bool, error) {
func IsInstalled(allExtensionNames []string) (string, bool, error) {
extensions, err := listExtensions()
if err != nil {
return "", false, err
Expand All @@ -44,19 +68,19 @@ func IsInstalled() (string, bool, error) {
return ext.String(), found, err
}

func InstallCommand() string {
return strings.Join(installCommand(false), " ")
func InstallCommand(extensionName string) string {
return strings.Join(installCommand(extensionName, false), " ")
}

func Install() error {
cmdSlice := installCommand(false)
func Install(extensionName string) error {
cmdSlice := installCommand(extensionName, false)
cmd := exec.Command(cmdSlice[0], cmdSlice[1:]...)
// TODO(adamb): error written to stderr is not returned
return cmd.Run()
}

func Update() error {
cmdSlice := installCommand(true)
func Update(extensionName string) error {
cmdSlice := installCommand(extensionName, true)
cmd := exec.Command(cmdSlice[0], cmdSlice[1:]...)
// TODO(adamb): error written to stderr is not returned
return cmd.Run()
Expand All @@ -83,14 +107,14 @@ func isInstalled(extensions []ext, searchedNames []string) (ext, bool, error) {
return ext{}, false, nil
}

func installCommand(force bool) []string {
func installCommand(extensionName string, force bool) []string {
cmd := []string{"code", "--install-extension"}
// --force will update if the extension is already installed.
// If it is not installed, --force has no effect.
if force {
cmd = append(cmd, "--force")
}
return append(cmd, defaultName)
return append(cmd, extensionName)
}

func isVSCodeInstalled() bool {
Expand Down
16 changes: 15 additions & 1 deletion internal/extension/extension_mock_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions internal/extension/extension_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"github.com/stretchr/testify/require"
)

var allExtensionNames = []string{DefaultExtensionName}

func Test_commandExists(t *testing.T) {
if runtime.GOOS == "windows" {
result := commandExists("cmd.exe")
Expand All @@ -29,7 +31,7 @@ func Test_isExtensionInstalled(t *testing.T) {
// Legacy extension installed.
var extensions []ext
for _, name := range allExtensionNames {
if name != defaultName {
if name != DefaultExtensionName {
extensions = append(extensions, ext{Name: name})
}
}
Expand All @@ -40,7 +42,7 @@ func Test_isExtensionInstalled(t *testing.T) {
require.Empty(t, version)

// Default extension installed.
extensions = []ext{{Name: defaultName}}
extensions = []ext{{Name: DefaultExtensionName}}
version, result, err = isInstalled(extensions, allExtensionNames)
require.NoError(t, err)
require.True(t, result)
Expand Down

0 comments on commit 5a8ba35

Please sign in to comment.