Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add WaitForMsg function #311

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion exp/teatest/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func TestAppInteractive(t *testing.T) {
t.Fatalf("output does not match: expected %q", string(bts))
}

teatest.WaitFor(t, tm.Output(), func(out []byte) bool {
teatest.WaitForOutput(t, tm.Output(), func(out []byte) bool {
return bytes.Contains(out, []byte("This program will exit in 7 seconds"))
}, teatest.WithDuration(5*time.Second), teatest.WithCheckInterval(time.Millisecond*10))

Expand Down
31 changes: 31 additions & 0 deletions exp/teatest/msg_buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package teatest

import (
"sync"

tea "github.com/charmbracelet/bubbletea"
)

// msgBuffer stores messages for checking in WaitForMsg.
type msgBuffer struct {
msgs []tea.Msg
mu sync.Mutex
}

func (b *msgBuffer) append(msg tea.Msg) {
b.mu.Lock()
defer b.mu.Unlock()
b.msgs = append(b.msgs, msg)
}

// forEach executes the given function for each message while holding the lock.
func (b *msgBuffer) forEach(fn func(msg tea.Msg) bool) tea.Msg {
b.mu.Lock()
defer b.mu.Unlock()
for _, msg := range b.msgs {
if fn(msg) {
return msg
}
}
return nil
}
32 changes: 32 additions & 0 deletions exp/teatest/msg_capture.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package teatest

import (
tea "github.com/charmbracelet/bubbletea"
)

// msgCaptureModel wraps a model to capture messages.
type msgCaptureModel struct {
model tea.Model
buffer *msgBuffer
}

func (m msgCaptureModel) Init() tea.Cmd {
return m.model.Init()
}

func (m msgCaptureModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.buffer.append(msg)
model, cmd := m.model.Update(msg)
if wrappedModel, ok := model.(msgCaptureModel); ok {
return wrappedModel, cmd
}

return msgCaptureModel{
model: model,
buffer: m.buffer,
}, cmd
}

func (m msgCaptureModel) View() string {
return m.model.View()
}
59 changes: 54 additions & 5 deletions exp/teatest/teatest.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,22 @@ func WithDuration(d time.Duration) WaitForOption {
}
}

// WaitFor keeps reading from r until the condition matches.
// WaitForOutput keeps reading from r until the condition matches.
// Default duration is 1s, default check interval is 50ms.
// These defaults can be changed with WithDuration and WithCheckInterval.
func WaitFor(
func WaitForOutput(
tb testing.TB,
r io.Reader,
condition func(bts []byte) bool,
options ...WaitForOption,
) {
tb.Helper()
if err := doWaitFor(r, condition, options...); err != nil {
if err := doWaitForOutput(r, condition, options...); err != nil {
tb.Fatal(err)
}
}

func doWaitFor(r io.Reader, condition func(bts []byte) bool, options ...WaitForOption) error {
func doWaitForOutput(r io.Reader, condition func(bts []byte) bool, options ...WaitForOption) error {
wf := WaitingForContext{
Duration: time.Second,
CheckInterval: 50 * time.Millisecond, //nolint: mnd
Expand Down Expand Up @@ -114,6 +114,8 @@ type TestModel struct {

done sync.Once
doneCh chan bool

msgs *msgBuffer
}

// NewTestModel makes a new TestModel which can be used for tests.
Expand All @@ -123,11 +125,19 @@ func NewTestModel(tb testing.TB, m tea.Model, options ...TestOption) *TestModel
out: safe(bytes.NewBuffer(nil)),
modelCh: make(chan tea.Model, 1),
doneCh: make(chan bool, 1),
msgs: &msgBuffer{
msgs: make([]tea.Msg, 0),
},
}

wrappedModel := msgCaptureModel{
model: m,
buffer: tm.msgs,
}

//nolint: staticcheck
tm.program = tea.NewProgram(
m,
wrappedModel,
tea.WithInput(tm.in),
tea.WithOutput(tm.out),
tea.WithoutSignals(),
Expand Down Expand Up @@ -162,6 +172,45 @@ func NewTestModel(tb testing.TB, m tea.Model, options ...TestOption) *TestModel
return tm
}

// WaitForMsg keeps checking messages until the condition matches or timeout is reached.
// Default duration is 1s, default check interval is 50ms.
func (tm *TestModel) WaitForMsg(
tb testing.TB,
condition func(msg tea.Msg) bool,
options ...WaitForOption,
) tea.Msg {
tb.Helper()
msg, err := tm.doWaitForMsg(condition, options...)
if err != nil {
tb.Fatal(err)
}
return msg
}

func (tm *TestModel) doWaitForMsg(
condition func(msg tea.Msg) bool,
options ...WaitForOption,
) (tea.Msg, error) {
wf := WaitingForContext{
Duration: time.Second,
CheckInterval: 50 * time.Millisecond,
}

for _, opt := range options {
opt(&wf)
}

start := time.Now()
for time.Since(start) <= wf.Duration {
if msg := tm.msgs.forEach(condition); msg != nil {
return msg, nil
}
time.Sleep(wf.CheckInterval)
}

return nil, fmt.Errorf("WaitForMsg: condition not met after %s", wf.Duration)
}

func (tm *TestModel) waitDone(tb testing.TB, opts []FinalOpt) {
tm.done.Do(func() {
fopts := FinalOpts{}
Expand Down
4 changes: 2 additions & 2 deletions exp/teatest/teatest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func TestWaitForErrorReader(t *testing.T) {
err := doWaitFor(iotest.ErrReader(fmt.Errorf("fake")), func(bts []byte) bool {
err := doWaitForOutput(iotest.ErrReader(fmt.Errorf("fake")), func(bts []byte) bool {
return true
}, WithDuration(time.Millisecond), WithCheckInterval(10*time.Microsecond))
if err == nil {
Expand All @@ -23,7 +23,7 @@ func TestWaitForErrorReader(t *testing.T) {
}

func TestWaitForTimeout(t *testing.T) {
err := doWaitFor(strings.NewReader("nope"), func(bts []byte) bool {
err := doWaitForOutput(strings.NewReader("nope"), func(bts []byte) bool {
return false
}, WithDuration(time.Millisecond), WithCheckInterval(10*time.Microsecond))
if err == nil {
Expand Down