diff --git a/lib/client/api.go b/lib/client/api.go index 5e620379d5dc3..c438345af3c53 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3501,7 +3501,7 @@ func playSession(sessionEvents []events.EventFields, stream []byte) error { ) // playback control goroutine go func() { - defer player.Stop() + defer player.RequestStop() var key [1]byte for { _, err := term.Stdin().Read(key[:]) diff --git a/lib/client/player.go b/lib/client/player.go index 872667aedec6d..3a9b042b72014 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -25,10 +25,13 @@ import ( "github.com/gravitational/teleport/lib/client/terminal" "github.com/gravitational/teleport/lib/events" + "github.com/jonboulle/clockwork" ) +type tshPlayerState int + const ( - stateStopped = iota + stateStopped tshPlayerState = iota stateStopping statePlaying ) @@ -37,40 +40,41 @@ const ( // and allows to control it type sessionPlayer struct { sync.Mutex + cond *sync.Cond + + state tshPlayerState + position int // position is the index of the last event successfully played back + + clock clockwork.Clock stream []byte sessionEvents []events.EventFields term *terminal.Terminal - state int - position int - // stopC is used to tell the caller that player has finished playing - stopC chan int + stopC chan int + stopOnce sync.Once } func newSessionPlayer(sessionEvents []events.EventFields, stream []byte, term *terminal.Terminal) *sessionPlayer { - return &sessionPlayer{ + p := &sessionPlayer{ + clock: clockwork.NewRealClock(), + position: -1, // position is the last successfully written event stream: stream, sessionEvents: sessionEvents, - stopC: make(chan int), term: term, + stopC: make(chan int), } + p.cond = sync.NewCond(p) + return p } func (p *sessionPlayer) Play() { p.playRange(0, 0) } -func (p *sessionPlayer) Stop() { +func (p *sessionPlayer) Stopped() bool { p.Lock() defer p.Unlock() - if p.stopC != nil { - close(p.stopC) - p.stopC = nil - } -} - -func (p *sessionPlayer) Stopped() bool { return p.state == stateStopped } @@ -78,7 +82,7 @@ func (p *sessionPlayer) Rewind() { p.Lock() defer p.Unlock() if p.state != stateStopped { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } if p.position > 0 { @@ -86,11 +90,17 @@ func (p *sessionPlayer) Rewind() { } } +func (p *sessionPlayer) stopRequested() bool { + p.Lock() + defer p.Unlock() + return p.state == stateStopping +} + func (p *sessionPlayer) Forward() { p.Lock() defer p.Unlock() if p.state != stateStopped { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } if p.position < len(p.sessionEvents) { @@ -102,20 +112,44 @@ func (p *sessionPlayer) TogglePause() { p.Lock() defer p.Unlock() if p.state == statePlaying { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } else { - p.playRange(p.position, 0) + p.playRange(p.position+1, 0) p.waitUntil(statePlaying) } } -func (p *sessionPlayer) waitUntil(state int) { +// RequestStop makes an asynchronous request for the player to stop playing. +// Playback may not stop before this method returns. +func (p *sessionPlayer) RequestStop() { + p.Lock() + defer p.Unlock() + + switch p.state { + case stateStopped, stateStopping: + // do nothing if stop already in progress + default: + p.setState(stateStopping) + } +} + +// waitUntil waits for the specified state to be reached. +// Callers must hold the lock on p.Mutex before calling. +func (p *sessionPlayer) waitUntil(state tshPlayerState) { for state != p.state { - time.Sleep(time.Millisecond) + p.cond.Wait() } } +// setState sets the current player state and notifies any +// goroutines waiting in waitUntil(). Callers must hold the +// lock on p.Mutex before calling. +func (p *sessionPlayer) setState(state tshPlayerState) { + p.state = state + p.cond.Broadcast() +} + // timestampFrame prints 'event timestamp' in the top right corner of the // terminal after playing every 'print' event func timestampFrame(term *terminal.Terminal, message string) { @@ -146,7 +180,9 @@ func timestampFrame(term *terminal.Terminal, message string) { // that playback starts from there. func (p *sessionPlayer) playRange(from, to int) { if to > len(p.sessionEvents) || from < 0 { - p.state = stateStopped + p.Lock() + p.setState(stateStopped) + p.Unlock() return } if to == 0 { @@ -154,48 +190,42 @@ func (p *sessionPlayer) playRange(from, to int) { } // clear screen between runs: os.Stdout.Write([]byte("\x1bc")) - // wait: waits between events during playback - prev := time.Duration(0) - wait := func(i int, e events.EventFields) { - ms := time.Duration(e.GetInt("ms")) - // before "from"? play that instantly: - if i >= from { - delay := ms - prev - // make playback smoother: - if delay < 10 { - delay = 0 - } - if delay > 250 && delay < 500 { - delay = 250 - } - if delay > 500 && delay < 1000 { - delay = 500 - } - if delay > 1000 { - delay = 1000 - } - timestampFrame(p.term, e.GetString("time")) - time.Sleep(time.Millisecond * delay) - } - prev = ms - } + // playback goroutine: go func() { + var i int + defer func() { - p.state = stateStopped + p.Lock() + p.setState(stateStopped) + p.Unlock() + + // played last event? + if i == len(p.sessionEvents) { + p.stopOnce.Do(func() { close(p.stopC) }) + } }() - p.state = statePlaying - i, offset, bytes := 0, 0, 0 + + p.Lock() + p.setState(statePlaying) + p.Unlock() + + prev := time.Duration(0) + offset, bytes := 0, 0 for i = 0; i < to; i++ { - if p.state == stateStopping { + if p.stopRequested() { return } + e := p.sessionEvents[i] switch e.GetString(events.EventType) { // 'print' event (output) case events.SessionPrintEvent: - wait(i, e) + // delay is only necessary once we've caught up to the "from" event + if i >= from { + prev = p.applyDelay(prev, e) + } offset = e.GetInt("offset") bytes = e.GetInt("bytes") os.Stdout.Write(p.stream[offset : offset+bytes]) @@ -211,11 +241,32 @@ func (p *sessionPlayer) playRange(from, to int) { default: continue } + p.Lock() p.position = i - } - // played last event? - if i == len(p.sessionEvents) { - p.Stop() + p.Unlock() } }() } + +// applyDelay waits until it is time to play back the current event. +// It returns the duration from the start of the session up until the current event. +func (p *sessionPlayer) applyDelay(previousTimestamp time.Duration, e events.EventFields) time.Duration { + eventTime := time.Duration(e.GetInt("ms") * int(time.Millisecond)) + delay := eventTime - previousTimestamp + + // make playback smoother: + switch { + case delay < 10*time.Millisecond: + delay = 0 + case delay > 250*time.Millisecond && delay < 500*time.Millisecond: + delay = 250 * time.Millisecond + case delay > 500*time.Millisecond && delay < 1*time.Second: + delay = 500 * time.Millisecond + case delay > time.Second: + delay = time.Second + } + + timestampFrame(p.term, e.GetString("time")) + p.clock.Sleep(delay) + return eventTime +} diff --git a/lib/client/player_test.go b/lib/client/player_test.go new file mode 100644 index 0000000000000..8abfb99afa184 --- /dev/null +++ b/lib/client/player_test.go @@ -0,0 +1,158 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "bytes" + "testing" + "time" + + "github.com/gravitational/teleport/lib/client/terminal" + "github.com/gravitational/teleport/lib/events" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +// TestEmptyPlay verifies that a playback of 0 events +// immediately transitions to a stopped state. +func TestEmptyPlay(t *testing.T) { + c := clockwork.NewFakeClock() + p := newSessionPlayer(nil, nil, testTerm(t)) + p.clock = c + + p.Play() + + select { + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for player to complete") + case <-p.stopC: + } + + require.True(t, p.Stopped()) +} + +// TestStop verifies that we can stop playback. +func TestStop(t *testing.T) { + c := clockwork.NewFakeClock() + events := printEvents(100, 200) + p := newSessionPlayer(events, nil, testTerm(t)) + p.clock = c + + p.Play() + + // wait for player to see the first event and apply the delay + c.BlockUntil(1) + + p.RequestStop() + + // advance the clock: + // at this point, the player will write the first event and then + // see that we requested a stop + c.Advance(100 * time.Millisecond) + + require.Eventually(t, p.Stopped, 2*time.Second, 200*time.Millisecond) +} + +// TestPlayPause verifies the play/pause functionality. +func TestPlayPause(t *testing.T) { + c := clockwork.NewFakeClock() + + // in this test, we let the player play 2 of the 3 events, + // then pause it and verify the pause state before resuming + // playback for the final event. + events := printEvents(100, 200, 300) + var stream []byte // intentionally empty, we dont care about stream contents here + p := newSessionPlayer(events, stream, testTerm(t)) + p.clock = c + + p.Play() + + // wait for player to see the first event and apply the delay + c.BlockUntil(1) + + // advance the clock: + // at this point, the player will write the first event + c.Advance(100 * time.Millisecond) + + // wait for the player to sleep on the 2nd event + c.BlockUntil(1) + + // pause playback + // note: we don't use p.TogglePause here, as it waits for the state transition, + // and the state won't transition proceed until we advance the clock + p.Lock() + p.setState(stateStopping) + p.Unlock() + + // advance the clock again: + // the player will write the second event and + // then realize that it's been asked to pause + c.Advance(100 * time.Millisecond) + + p.Lock() + p.waitUntil(stateStopped) + p.Unlock() + + ch := make(chan struct{}) + go func() { + // resume playback + p.TogglePause() + ch <- struct{}{} + }() + + // playback should resume for the 3rd and final event: + // in this case, the first two events are written immediately without delay, + // and we block here until the player is sleeping prior to the 3rd event + c.BlockUntil(1) + + // make sure that we've resumed + <-ch + require.False(t, p.Stopped()) + + // advance the clock a final time, forcing the player to write the last event + // note: on the resume, we play the successful events immediately, and then sleep + // up to the resume point, which is why we advance by 300ms here + c.Advance(300 * time.Millisecond) + + select { + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for player to complete") + case <-p.stopC: + } + require.True(t, p.Stopped()) +} + +func testTerm(t *testing.T) *terminal.Terminal { + t.Helper() + term, err := terminal.New(bytes.NewReader(nil), &bytes.Buffer{}, &bytes.Buffer{}) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, term.Close()) + }) + return term +} + +func printEvents(delays ...int) []events.EventFields { + result := make([]events.EventFields, len(delays)) + for i := range result { + result[i] = events.EventFields{ + events.EventType: events.SessionPrintEvent, + "ms": delays[i], + } + } + return result +}