Skip to content

Commit

Permalink
Use PATH from session to look for programs (#555)
Browse files Browse the repository at this point in the history
If `PATH` is present in the session, use it to look for the program
paths. This is especially important for situation when within a session
a virtual env is created and consecutive code blocks should be executed
within it.

Fixes #552

Port this change to #548.
  • Loading branch information
adambabik authored Apr 17, 2024
1 parent d09f03e commit 5ebe0ff
Show file tree
Hide file tree
Showing 12 changed files with 194 additions and 15 deletions.
2 changes: 2 additions & 0 deletions internal/cmd/code_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/pkg/errors"
"github.com/spf13/cobra"
"github.com/stateful/runme/v3/internal/runner"
"github.com/stateful/runme/v3/internal/system"
"github.com/stateful/runme/v3/internal/tui"
"go.uber.org/zap"
)
Expand Down Expand Up @@ -101,6 +102,7 @@ func codeServerCmd() *cobra.Command {
Stdout: cmd.OutOrStdout(),
Stderr: cmd.ErrOrStderr(),
Stdin: stdin,
System: system.Default,
Logger: zap.NewNop(),
}

Expand Down
2 changes: 2 additions & 0 deletions internal/runner/client/client_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/stateful/runme/v3/internal/document"
"github.com/stateful/runme/v3/internal/project"
"github.com/stateful/runme/v3/internal/runner"
"github.com/stateful/runme/v3/internal/system"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -91,6 +92,7 @@ func (r *LocalRunner) newExecutable(task project.Task) (runner.Executable, error
Stdout: r.stdout,
Stderr: r.stderr,
Session: r.session,
System: system.Default,
Logger: r.logger,
}

Expand Down
22 changes: 17 additions & 5 deletions internal/runner/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ import (

"github.com/creack/pty"
"github.com/pkg/errors"
"github.com/stateful/runme/v3/internal/ulid"
"go.uber.org/multierr"
"go.uber.org/zap"

"github.com/stateful/runme/v3/internal/system"
"github.com/stateful/runme/v3/internal/ulid"
)

const (
Expand Down Expand Up @@ -102,10 +104,20 @@ type commandConfig struct {
}

func newCommand(cfg *commandConfig) (*command, error) {
// If PATH is set in the session, use it in the system
// so that program paths can be resolved correctly.
// This is especially important for virtual envs.
sys := system.Default
if cfg.Session != nil {
if pathEnv, err := cfg.Session.envStorer.getEnv("PATH"); err == nil && pathEnv != "" {
sys = system.New(system.WithPathEnvGetter(func() string { return pathEnv }))
}
}

programName, initialArgs := parseFileProgram(cfg.ProgramName)
args := initialArgs

programPath, initialArgs, err := inferFileProgram(programName, cfg.LanguageID)
programPath, initialArgs, err := inferFileProgram(sys, programName, cfg.LanguageID)
args = append(args, initialArgs...)
if err != nil {
return nil, errors.WithStack(err)
Expand Down Expand Up @@ -658,9 +670,9 @@ func parseFileProgram(programPath string) (program string, args []string) {
return
}

func inferFileProgram(programPath string, languageID string) (interpreter string, args []string, err error) {
func inferFileProgram(sys *system.System, programPath string, languageID string) (interpreter string, args []string, err error) {
if programPath != "" {
res, err := exec.LookPath(programPath)
res, err := sys.LookPath(programPath)
if err != nil {
return "", []string{}, ErrInvalidProgram{
Program: programPath,
Expand All @@ -672,7 +684,7 @@ func inferFileProgram(programPath string, languageID string) (interpreter string

for _, candidate := range programByLanguageID[languageID] {
program, args := parseFileProgram(candidate)
res, err := exec.LookPath(program)
res, err := sys.LookPath(program)
if err == nil {
return res, args, nil
}
Expand Down
6 changes: 6 additions & 0 deletions internal/runner/env_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ func (s *envStore) Add(envs ...string) *envStore {
return s
}

func (s *envStore) Get(k string) string {
s.mu.RLock()
defer s.mu.RUnlock()
return s.values[k]
}

func getEnvSizeContribution(k, v string) int {
// +2 for the '=' and '\0' separators
return len(k) + len(v) + 2
Expand Down
5 changes: 4 additions & 1 deletion internal/runner/executable.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"io"

"github.com/stateful/runme/v3/internal/executable"
"go.uber.org/zap"

"github.com/stateful/runme/v3/internal/executable"
"github.com/stateful/runme/v3/internal/system"
)

type Executable interface {
Expand All @@ -24,6 +26,7 @@ type ExecutableConfig struct {
PreEnv []string
PostEnv []string
Session *Session
System *system.System
Logger *zap.Logger
}

Expand Down
4 changes: 2 additions & 2 deletions internal/runner/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type Go struct {
var _ Executable = (*Go)(nil)

func (g Go) DryRun(ctx context.Context, w io.Writer) {
_, err := exec.LookPath("go")
_, err := g.System.LookPath("go")
if err != nil {
_, _ = fmt.Fprintf(w, "failed to find %q executable: %s\n", "go", err)
}
Expand All @@ -30,7 +30,7 @@ func (g Go) DryRun(ctx context.Context, w io.Writer) {
}

func (g *Go) Run(ctx context.Context) error {
executable, err := exec.LookPath("go")
executable, err := g.System.LookPath("go")
if err != nil {
return errors.Wrapf(err, "failed to find %q executable", "go")
}
Expand Down
23 changes: 23 additions & 0 deletions internal/runner/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package runner
import (
"context"
"fmt"
"strings"
"sync"

lru "github.com/hashicorp/golang-lru/v2"
Expand All @@ -15,6 +16,7 @@ import (
var owlStoreDefault = false

type envStorer interface {
getEnv(string) (string, error)
envs() ([]string, error)
sensitiveEnvKeys() ([]string, error)
addEnvs(envs []string) error
Expand Down Expand Up @@ -139,6 +141,10 @@ func (es *runnerEnvStorer) sensitiveEnvKeys() ([]string, error) {
return []string{}, nil
}

func (es *runnerEnvStorer) getEnv(name string) (string, error) {
return es.envStore.Get(name), nil
}

func (es *runnerEnvStorer) envs() ([]string, error) {
envs, err := es.envStore.Values()
if err != nil {
Expand Down Expand Up @@ -309,6 +315,23 @@ func (es *owlEnvStorer) addEnvs(envs []string) error {
return nil
}

func (es *owlEnvStorer) getEnv(name string) (string, error) {
env, err := es.owlStore.InsecureValues()
if err != nil {
return "", err
}

prefix := name + "="

for _, item := range env {
if strings.HasPrefix(item, prefix) {
return item[len(prefix):], nil
}
}

return "", nil
}

func (es *owlEnvStorer) sensitiveEnvKeys() ([]string, error) {
vals, err := es.owlStore.SensitiveKeys()
if err != nil {
Expand Down
16 changes: 9 additions & 7 deletions internal/runner/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ import (
"syscall"

"github.com/pkg/errors"

"github.com/stateful/runme/v3/internal/document"
"github.com/stateful/runme/v3/internal/system"
)

type Shell struct {
Expand All @@ -27,7 +29,7 @@ type Shell struct {
var _ Executable = (*Shell)(nil)

func (s Shell) ProgramPath() string {
return ResolveShellPath(s.CustomShell)
return resolveShellPath(s.System, s.CustomShell)
}

func (s Shell) ShellType() string {
Expand Down Expand Up @@ -134,7 +136,7 @@ func IsShellLanguage(languageID string) bool {

func GetCellProgram(languageID string, customShell string, cell *document.CodeBlock) (program string, commandMode CommandMode) {
if IsShellLanguage(languageID) {
program = ResolveShellPath(customShell)
program = customShell
commandMode = CommandModeInlineShell
} else {
commandMode = CommandModeTempFile
Expand All @@ -147,22 +149,22 @@ func GetCellProgram(languageID string, customShell string, cell *document.CodeBl
return
}

func ResolveShellPath(customShell string) string {
func resolveShellPath(sys *system.System, customShell string) string {
if customShell != "" {
if path, err := exec.LookPath(customShell); err == nil {
if path, err := sys.LookPath(customShell); err == nil {
return path
}
}

return GlobalShellPath()
return globalShellPath(sys)
}

func GlobalShellPath() string {
func globalShellPath(sys *system.System) string {
shell, ok := os.LookupEnv("SHELL")
if !ok {
shell = "sh"
}
if path, err := exec.LookPath(shell); err == nil {
if path, err := sys.LookPath(shell); err == nil {
return path
}
return "/bin/sh"
Expand Down
39 changes: 39 additions & 0 deletions internal/system/lookpath.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package system

import (
"os"
)

var Default = newDefault()

func newDefault() *System {
return &System{
getPathEnv: func() string { return os.Getenv("PATH") },
}
}

type Option func(*System)

func WithPathEnvGetter(fn func() string) Option {
return func(s *System) {
s.getPathEnv = fn
}
}

type System struct {
getPathEnv func() string
}

func New(opts ...Option) *System {
s := newDefault()

for _, opt := range opts {
opt(s)
}

return s
}

func (s *System) LookPath(file string) (string, error) {
return lookPath(s.getPathEnv(), file)
}
30 changes: 30 additions & 0 deletions internal/system/lookpath_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//go:build unix

// TODO(adamb): remove the build flag when [System.LookPath] is implemented for Windows.

package system

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestLookPath(t *testing.T) {
tmp := t.TempDir()
myBinaryPath := filepath.Join(tmp, "my-binary")

// Create an empty file with execute permission.
err := os.WriteFile(myBinaryPath, []byte{}, 0o111)
require.NoError(t, err)

s := New(
WithPathEnvGetter(func() string { return tmp }),
)
path, err := s.LookPath("my-binary")
require.NoError(t, err)
assert.Equal(t, myBinaryPath, path)
}
51 changes: 51 additions & 0 deletions internal/system/lookpath_unix.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//go:build unix

package system

import (
"io/fs"
"os"
"os/exec"
"path/filepath"
"strings"
"syscall"
)

func lookPath(pathEnv, file string) (string, error) {
if strings.Contains(file, "/") {
err := findExecutable(file)
if err == nil {
return file, nil
}
return "", &exec.Error{Name: file, Err: err}
}
for _, dir := range filepath.SplitList(pathEnv) {
if dir == "" {
// Unix shell semantics: path element "" means "."
dir = "."
}
path := filepath.Join(dir, file)
if err := findExecutable(path); err == nil {
if !filepath.IsAbs(path) {
return path, &exec.Error{Name: file, Err: exec.ErrDot}
}
return path, nil
}
}
return "", &exec.Error{Name: file, Err: exec.ErrNotFound}
}

func findExecutable(file string) error {
d, err := os.Stat(file)
if err != nil {
return err
}
m := d.Mode()
if m.IsDir() {
return syscall.EISDIR
}
if m&0o111 != 0 {
return nil
}
return fs.ErrPermission
}
9 changes: 9 additions & 0 deletions internal/system/lookpath_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package system

import "os/exec"

func lookPath(_, file string) (string, error) {
// TODO(adamb): implement this for Windows.
// Check out https://github.com/golang/go/blob/master/src/os/exec/lp_windows.go.
return exec.LookPath(file)
}

0 comments on commit 5ebe0ff

Please sign in to comment.