Skip to content

Commit

Permalink
Refactor Normalize() in the command pkg (#506)
Browse files Browse the repository at this point in the history
Nothing fancy, just a small refactoring of `internal/command` and config
normalizers.

Co-authored-by: Sebastian Tiedtke <sebastiantiedtke@gmail.com>
  • Loading branch information
adambabik and sourishkrout authored Feb 20, 2024
1 parent a1df712 commit cff4cd9
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 212 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"

"github.com/pkg/errors"
"go.uber.org/multierr"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"

Expand Down Expand Up @@ -40,7 +41,11 @@ type argsNormalizer struct {
scriptFile *os.File
}

func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) {
func newArgsNormalizer(session *Session, logger *zap.Logger) configNormalizer {
return (&argsNormalizer{session: session, logger: logger}).Normalize
}

func (n *argsNormalizer) Normalize(cfg *Config) (*Config, func() error, error) {
args := append([]string{}, cfg.Arguments...)

switch cfg.Mode {
Expand All @@ -51,7 +56,7 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) {

if isShellLanguage(filepath.Base(cfg.ProgramName)) {
if err := n.inlineShell(cfg, &buf); err != nil {
return nil, err
return nil, nil, err
}
} else {
// Write the script from the commands or the script.
Expand All @@ -65,21 +70,22 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) {
}
}

// TODO(adamb): "-c" is not supported for all inline programs.
if val := buf.String(); val != "" {
args = append(args, "-c", val)
}

case *runnerv2alpha1.CommandMode_COMMAND_MODE_FILE.Enum():
if err := n.createTempDir(); err != nil {
return nil, err
return nil, nil, err
}

if err := n.createScriptFile(); err != nil {
return nil, err
return nil, nil, err
}

if err := n.writeScript([]byte(cfg.GetScript())); err != nil {
return nil, err
return nil, nil, err
}

// TODO(adamb): it's not always true that the script-based program
Expand All @@ -89,7 +95,7 @@ func (n *argsNormalizer) Normalize(cfg *Config) (*Config, error) {

result := proto.Clone(cfg).(*Config)
result.Arguments = args
return result, nil
return result, n.cleanup, nil
}

func (n *argsNormalizer) inlineShell(cfg *Config, buf *strings.Builder) error {
Expand Down Expand Up @@ -125,7 +131,17 @@ func (n *argsNormalizer) inlineShell(cfg *Config, buf *strings.Builder) error {
return nil
}

func (n *argsNormalizer) Cleanup() error {
func (n *argsNormalizer) cleanup() (result error) {
if err := n.collectEnv(); err != nil {
result = multierr.Append(result, err)
}
if err := n.removeTempDir(); err != nil {
result = multierr.Append(result, err)
}
return
}

func (n *argsNormalizer) removeTempDir() error {
if n.tempDir == "" {
return nil
}
Expand All @@ -139,7 +155,7 @@ func (n *argsNormalizer) Cleanup() error {
return nil
}

func (n *argsNormalizer) CollectEnv() error {
func (n *argsNormalizer) collectEnv() error {
if n.session == nil || !n.isEnvCollectable {
return nil
}
Expand Down Expand Up @@ -236,23 +252,17 @@ func splitNull(data []byte, atEOF bool) (advance int, token []byte, err error) {
return 0, nil, nil
}

type envSource func() []string
func shellOptionsFromProgram(programPath string) (res string) {
base := filepath.Base(programPath)
shell := base[:len(base)-len(filepath.Ext(base))]

type envNormalizer struct {
sources []envSource
}

func (n *envNormalizer) Normalize(cfg *Config) (*Config, error) {
result := proto.Clone(cfg).(*Config)

env := os.Environ()
env = append(env, cfg.Env...)

for _, source := range n.sources {
env = append(env, source()...)
// TODO(mxs): powershell and DOS are missing
switch shell {
case "zsh", "ksh", "bash":
res += "set -e -o pipefail"
case "sh":
res += "set -e"
}

result.Env = env

return result, nil
return
}
15 changes: 6 additions & 9 deletions internal/command/command_native.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,21 +49,18 @@ func (c *NativeCommand) SetWinsize(rows, cols, x, y uint16) error {
}

func (c *NativeCommand) Start(ctx context.Context) (err error) {
argsNormalizer := &argsNormalizer{
session: c.opts.Session,
logger: c.logger,
}

cfg, err := normalizeConfig(
cfg, cleanups, err := normalizeConfig(
c.cfg,
argsNormalizer,
&envNormalizer{sources: []envSource{c.opts.Session.GetEnv}},
pathNormalizer,
modeNormalizer,
newArgsNormalizer(c.opts.Session, c.logger),
newEnvNormalizer(c.opts.Session.GetEnv),
)
if err != nil {
return
}

c.cleanFuncs = append(c.cleanFuncs, argsNormalizer.CollectEnv, argsNormalizer.Cleanup)
c.cleanFuncs = append(c.cleanFuncs, cleanups...)

stdin := c.opts.Stdin

Expand Down
15 changes: 6 additions & 9 deletions internal/command/command_virtual.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,21 +107,18 @@ func (c *VirtualCommand) Pid() int {
}

func (c *VirtualCommand) Start(ctx context.Context) (err error) {
argsNormalizer := &argsNormalizer{
session: c.opts.Session,
logger: c.logger,
}

cfg, err := normalizeConfig(
cfg, cleanups, err := normalizeConfig(
c.cfg,
argsNormalizer,
&envNormalizer{sources: []envSource{c.opts.Session.GetEnv}},
pathNormalizer,
modeNormalizer,
newArgsNormalizer(c.opts.Session, c.logger),
newEnvNormalizer(c.opts.Session.GetEnv),
)
if err != nil {
return
}

c.cleanFuncs = append(c.cleanFuncs, argsNormalizer.CollectEnv, argsNormalizer.Cleanup)
c.cleanFuncs = append(c.cleanFuncs, cleanups...)

c.pty, c.tty, err = pty.Open()
if err != nil {
Expand Down
188 changes: 18 additions & 170 deletions internal/command/config.go
Original file line number Diff line number Diff line change
@@ -1,35 +1,32 @@
package command

import (
"fmt"
"os/exec"
"path/filepath"
"strings"

"google.golang.org/protobuf/proto"

runnerv2alpha1 "github.com/stateful/runme/internal/gen/proto/go/runme/runner/v2alpha1"
)

type ErrUnsupportedLanguage struct {
langID string
}
// Config contains a serializable configuration for a command.
// It's agnostic to the runtime or particular execution settings.
type Config = runnerv2alpha1.ProgramConfig

func (e ErrUnsupportedLanguage) Error() string {
return fmt.Sprintf("unsupported language %s", e.langID)
}
type configNormalizer func(*Config) (*Config, func() error, error)

type ErrInterpretersNotFound struct {
interpreters []string
}
func normalizeConfig(cfg *Config, normalizers ...configNormalizer) (_ *Config, cleanups []func() error, err error) {
for _, normalizer := range normalizers {
var cleanup func() error

func (e ErrInterpretersNotFound) Error() string {
return fmt.Sprintf("unable to look up any of interpreters %q", e.interpreters)
}
cfg, cleanup, err = normalizer(cfg)
if err != nil {
return nil, nil, err
}

// Config contains a serializable configuration for a command.
// It's agnostic to the runtime or particular execution settings.
type Config = runnerv2alpha1.ProgramConfig
if cleanup != nil {
cleanups = append(cleanups, cleanup)
}
}
return cfg, cleanups, nil
}

// redactConfig returns a new Config instance and copies only fields considered safe.
// Useful for logging.
Expand All @@ -44,106 +41,7 @@ func redactConfig(cfg *Config) *Config {
}
}

func normalizeConfig(cfg *Config, extra ...configNormalizer) (*Config, error) {
normalizers := []configNormalizer{
&pathNormalizer{},
&modeNormalizer{},
}

normalizers = append(normalizers, extra...)

for _, normalizer := range normalizers {
var err error

if cfg, err = normalizer.Normalize(cfg); err != nil {
return nil, err
}
}

return cfg, nil
}

type configNormalizer interface {
Normalize(*Config) (*Config, error)
}

type pathNormalizer struct{}

func (n *pathNormalizer) Normalize(cfg *Config) (*Config, error) {
programPath, err := exec.LookPath(cfg.ProgramName)
if err == nil {
if programPath == cfg.ProgramName {
return cfg, nil
}

result := proto.Clone(cfg).(*Config)
result.ProgramName = programPath

return result, nil
}

interpreters := inferInterpreterFromLanguage(cfg.ProgramName)
if len(interpreters) == 0 {
return nil, &ErrUnsupportedLanguage{langID: cfg.ProgramName}
}

for _, interpreter := range interpreters {
program, args := parseInterpreter(interpreter)
if programPath, err := exec.LookPath(program); err == nil {
result := proto.Clone(cfg).(*Config)
result.ProgramName = programPath
result.Arguments = args
return result, nil
}
}

return nil, &ErrInterpretersNotFound{interpreters: interpreters}
}

type modeNormalizer struct{}

func (n *modeNormalizer) Normalize(cfg *Config) (*Config, error) {
if cfg.Mode != runnerv2alpha1.CommandMode_COMMAND_MODE_UNSPECIFIED {
return cfg, nil
}

result := proto.Clone(cfg).(*Config)

if isShellLanguage(filepath.Base(result.ProgramName)) {
result.Mode = runnerv2alpha1.CommandMode_COMMAND_MODE_INLINE
} else {
result.Mode = runnerv2alpha1.CommandMode_COMMAND_MODE_FILE
}

return result, nil
}

func prepareScriptFromLines(programPath string, lines []string) string {
var buf strings.Builder

for _, cmd := range lines {
_, _ = buf.WriteString(cmd)
_, _ = buf.WriteRune('\n')
}

return buf.String()
}

func shellOptionsFromProgram(programPath string) (res string) {
base := filepath.Base(programPath)
shell := base[:len(base)-len(filepath.Ext(base))]

// TODO(mxs): powershell and DOS are missing
switch shell {
case "zsh", "ksh", "bash":
res += "set -e -o pipefail"
case "sh":
res += "set -e"
}

return
}

// TODO(adamb): this function is used for two quite different inputs: program name and language ID.
func isShellLanguage(languageID string) bool {
switch strings.ToLower(languageID) {
// shellscripts
Expand All @@ -169,53 +67,3 @@ func isShellLanguage(languageID string) bool {
return false
}
}

// parseInterpreter handles cases when the interpreter is, for instance, "deno run".
// Only the first word is a program name and the rest is arguments.
func parseInterpreter(interpreter string) (program string, args []string) {
parts := strings.SplitN(interpreter, " ", 2)

if len(parts) > 0 {
program = parts[0]
}

if len(parts) > 1 {
args = strings.Split(parts[1], " ")
}

return
}

var interpreterByLanguageID = map[string][]string{
"js": {"node"},
"javascript": {"node"},
"jsx": {"node"},
"javascriptreact": {"node"},

"ts": {"ts-node", "deno run", "bun run"},
"typescript": {"ts-node", "deno run", "bun run"},
"tsx": {"ts-node", "deno run", "bun run"},
"typescriptreact": {"ts-node", "deno run", "bun run"},

"sh": {"bash", "sh"},
"bash": {"bash", "sh"},
"ksh": {"ksh"},
"zsh": {"zsh"},
"fish": {"fish"},
"powershell": {"powershell"},
"cmd": {"cmd"},
"dos": {"cmd"},
"shellscript": {"bash", "sh"},

"lua": {"lua"},
"perl": {"perl"},
"php": {"php"},
"python": {"python3", "python"},
"py": {"python3", "python"},
"ruby": {"ruby"},
"rb": {"ruby"},
}

func inferInterpreterFromLanguage(langID string) []string {
return interpreterByLanguageID[langID]
}
Loading

0 comments on commit cff4cd9

Please sign in to comment.