From 789cb79816132d0b0ceeb955f2abc3b370601877 Mon Sep 17 00:00:00 2001 From: Zac Bergquist Date: Tue, 29 Mar 2022 16:25:27 -0600 Subject: [PATCH] Fix tsh player issues (#11491) This commit fixes race conditions in the tsh session player by using a condition variable to detect state changes rather than unsafely polling a variable that is written by a separate goroutine. In addition, fix an off by one error when resuming playback after pausing. The player's position variable has always stored the index of the last succesfully played event, so when we resume playback we should start at position+1 to not re-play the previous event twice. Fixes #11479 --- lib/client/api.go | 2 +- lib/client/player.go | 163 +++++++++++++++++++++++++------------- lib/client/player_test.go | 158 ++++++++++++++++++++++++++++++++++++ 3 files changed, 266 insertions(+), 57 deletions(-) create mode 100644 lib/client/player_test.go diff --git a/lib/client/api.go b/lib/client/api.go index 5e620379d5dc3..c438345af3c53 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3501,7 +3501,7 @@ func playSession(sessionEvents []events.EventFields, stream []byte) error { ) // playback control goroutine go func() { - defer player.Stop() + defer player.RequestStop() var key [1]byte for { _, err := term.Stdin().Read(key[:]) diff --git a/lib/client/player.go b/lib/client/player.go index 872667aedec6d..3a9b042b72014 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -25,10 +25,13 @@ import ( "github.com/gravitational/teleport/lib/client/terminal" "github.com/gravitational/teleport/lib/events" + "github.com/jonboulle/clockwork" ) +type tshPlayerState int + const ( - stateStopped = iota + stateStopped tshPlayerState = iota stateStopping statePlaying ) @@ -37,40 +40,41 @@ const ( // and allows to control it type sessionPlayer struct { sync.Mutex + cond *sync.Cond + + state tshPlayerState + position int // position is the index of the last event successfully played back + + clock clockwork.Clock stream []byte sessionEvents []events.EventFields term *terminal.Terminal - state int - position int - // stopC is used to tell the caller that player has finished playing - stopC chan int + stopC chan int + stopOnce sync.Once } func newSessionPlayer(sessionEvents []events.EventFields, stream []byte, term *terminal.Terminal) *sessionPlayer { - return &sessionPlayer{ + p := &sessionPlayer{ + clock: clockwork.NewRealClock(), + position: -1, // position is the last successfully written event stream: stream, sessionEvents: sessionEvents, - stopC: make(chan int), term: term, + stopC: make(chan int), } + p.cond = sync.NewCond(p) + return p } func (p *sessionPlayer) Play() { p.playRange(0, 0) } -func (p *sessionPlayer) Stop() { +func (p *sessionPlayer) Stopped() bool { p.Lock() defer p.Unlock() - if p.stopC != nil { - close(p.stopC) - p.stopC = nil - } -} - -func (p *sessionPlayer) Stopped() bool { return p.state == stateStopped } @@ -78,7 +82,7 @@ func (p *sessionPlayer) Rewind() { p.Lock() defer p.Unlock() if p.state != stateStopped { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } if p.position > 0 { @@ -86,11 +90,17 @@ func (p *sessionPlayer) Rewind() { } } +func (p *sessionPlayer) stopRequested() bool { + p.Lock() + defer p.Unlock() + return p.state == stateStopping +} + func (p *sessionPlayer) Forward() { p.Lock() defer p.Unlock() if p.state != stateStopped { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } if p.position < len(p.sessionEvents) { @@ -102,20 +112,44 @@ func (p *sessionPlayer) TogglePause() { p.Lock() defer p.Unlock() if p.state == statePlaying { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } else { - p.playRange(p.position, 0) + p.playRange(p.position+1, 0) p.waitUntil(statePlaying) } } -func (p *sessionPlayer) waitUntil(state int) { +// RequestStop makes an asynchronous request for the player to stop playing. +// Playback may not stop before this method returns. +func (p *sessionPlayer) RequestStop() { + p.Lock() + defer p.Unlock() + + switch p.state { + case stateStopped, stateStopping: + // do nothing if stop already in progress + default: + p.setState(stateStopping) + } +} + +// waitUntil waits for the specified state to be reached. +// Callers must hold the lock on p.Mutex before calling. +func (p *sessionPlayer) waitUntil(state tshPlayerState) { for state != p.state { - time.Sleep(time.Millisecond) + p.cond.Wait() } } +// setState sets the current player state and notifies any +// goroutines waiting in waitUntil(). Callers must hold the +// lock on p.Mutex before calling. +func (p *sessionPlayer) setState(state tshPlayerState) { + p.state = state + p.cond.Broadcast() +} + // timestampFrame prints 'event timestamp' in the top right corner of the // terminal after playing every 'print' event func timestampFrame(term *terminal.Terminal, message string) { @@ -146,7 +180,9 @@ func timestampFrame(term *terminal.Terminal, message string) { // that playback starts from there. func (p *sessionPlayer) playRange(from, to int) { if to > len(p.sessionEvents) || from < 0 { - p.state = stateStopped + p.Lock() + p.setState(stateStopped) + p.Unlock() return } if to == 0 { @@ -154,48 +190,42 @@ func (p *sessionPlayer) playRange(from, to int) { } // clear screen between runs: os.Stdout.Write([]byte("\x1bc")) - // wait: waits between events during playback - prev := time.Duration(0) - wait := func(i int, e events.EventFields) { - ms := time.Duration(e.GetInt("ms")) - // before "from"? play that instantly: - if i >= from { - delay := ms - prev - // make playback smoother: - if delay < 10 { - delay = 0 - } - if delay > 250 && delay < 500 { - delay = 250 - } - if delay > 500 && delay < 1000 { - delay = 500 - } - if delay > 1000 { - delay = 1000 - } - timestampFrame(p.term, e.GetString("time")) - time.Sleep(time.Millisecond * delay) - } - prev = ms - } + // playback goroutine: go func() { + var i int + defer func() { - p.state = stateStopped + p.Lock() + p.setState(stateStopped) + p.Unlock() + + // played last event? + if i == len(p.sessionEvents) { + p.stopOnce.Do(func() { close(p.stopC) }) + } }() - p.state = statePlaying - i, offset, bytes := 0, 0, 0 + + p.Lock() + p.setState(statePlaying) + p.Unlock() + + prev := time.Duration(0) + offset, bytes := 0, 0 for i = 0; i < to; i++ { - if p.state == stateStopping { + if p.stopRequested() { return } + e := p.sessionEvents[i] switch e.GetString(events.EventType) { // 'print' event (output) case events.SessionPrintEvent: - wait(i, e) + // delay is only necessary once we've caught up to the "from" event + if i >= from { + prev = p.applyDelay(prev, e) + } offset = e.GetInt("offset") bytes = e.GetInt("bytes") os.Stdout.Write(p.stream[offset : offset+bytes]) @@ -211,11 +241,32 @@ func (p *sessionPlayer) playRange(from, to int) { default: continue } + p.Lock() p.position = i - } - // played last event? - if i == len(p.sessionEvents) { - p.Stop() + p.Unlock() } }() } + +// applyDelay waits until it is time to play back the current event. +// It returns the duration from the start of the session up until the current event. +func (p *sessionPlayer) applyDelay(previousTimestamp time.Duration, e events.EventFields) time.Duration { + eventTime := time.Duration(e.GetInt("ms") * int(time.Millisecond)) + delay := eventTime - previousTimestamp + + // make playback smoother: + switch { + case delay < 10*time.Millisecond: + delay = 0 + case delay > 250*time.Millisecond && delay < 500*time.Millisecond: + delay = 250 * time.Millisecond + case delay > 500*time.Millisecond && delay < 1*time.Second: + delay = 500 * time.Millisecond + case delay > time.Second: + delay = time.Second + } + + timestampFrame(p.term, e.GetString("time")) + p.clock.Sleep(delay) + return eventTime +} diff --git a/lib/client/player_test.go b/lib/client/player_test.go new file mode 100644 index 0000000000000..8abfb99afa184 --- /dev/null +++ b/lib/client/player_test.go @@ -0,0 +1,158 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "bytes" + "testing" + "time" + + "github.com/gravitational/teleport/lib/client/terminal" + "github.com/gravitational/teleport/lib/events" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +// TestEmptyPlay verifies that a playback of 0 events +// immediately transitions to a stopped state. +func TestEmptyPlay(t *testing.T) { + c := clockwork.NewFakeClock() + p := newSessionPlayer(nil, nil, testTerm(t)) + p.clock = c + + p.Play() + + select { + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for player to complete") + case <-p.stopC: + } + + require.True(t, p.Stopped()) +} + +// TestStop verifies that we can stop playback. +func TestStop(t *testing.T) { + c := clockwork.NewFakeClock() + events := printEvents(100, 200) + p := newSessionPlayer(events, nil, testTerm(t)) + p.clock = c + + p.Play() + + // wait for player to see the first event and apply the delay + c.BlockUntil(1) + + p.RequestStop() + + // advance the clock: + // at this point, the player will write the first event and then + // see that we requested a stop + c.Advance(100 * time.Millisecond) + + require.Eventually(t, p.Stopped, 2*time.Second, 200*time.Millisecond) +} + +// TestPlayPause verifies the play/pause functionality. +func TestPlayPause(t *testing.T) { + c := clockwork.NewFakeClock() + + // in this test, we let the player play 2 of the 3 events, + // then pause it and verify the pause state before resuming + // playback for the final event. + events := printEvents(100, 200, 300) + var stream []byte // intentionally empty, we dont care about stream contents here + p := newSessionPlayer(events, stream, testTerm(t)) + p.clock = c + + p.Play() + + // wait for player to see the first event and apply the delay + c.BlockUntil(1) + + // advance the clock: + // at this point, the player will write the first event + c.Advance(100 * time.Millisecond) + + // wait for the player to sleep on the 2nd event + c.BlockUntil(1) + + // pause playback + // note: we don't use p.TogglePause here, as it waits for the state transition, + // and the state won't transition proceed until we advance the clock + p.Lock() + p.setState(stateStopping) + p.Unlock() + + // advance the clock again: + // the player will write the second event and + // then realize that it's been asked to pause + c.Advance(100 * time.Millisecond) + + p.Lock() + p.waitUntil(stateStopped) + p.Unlock() + + ch := make(chan struct{}) + go func() { + // resume playback + p.TogglePause() + ch <- struct{}{} + }() + + // playback should resume for the 3rd and final event: + // in this case, the first two events are written immediately without delay, + // and we block here until the player is sleeping prior to the 3rd event + c.BlockUntil(1) + + // make sure that we've resumed + <-ch + require.False(t, p.Stopped()) + + // advance the clock a final time, forcing the player to write the last event + // note: on the resume, we play the successful events immediately, and then sleep + // up to the resume point, which is why we advance by 300ms here + c.Advance(300 * time.Millisecond) + + select { + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for player to complete") + case <-p.stopC: + } + require.True(t, p.Stopped()) +} + +func testTerm(t *testing.T) *terminal.Terminal { + t.Helper() + term, err := terminal.New(bytes.NewReader(nil), &bytes.Buffer{}, &bytes.Buffer{}) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, term.Close()) + }) + return term +} + +func printEvents(delays ...int) []events.EventFields { + result := make([]events.EventFields, len(delays)) + for i := range result { + result[i] = events.EventFields{ + events.EventType: events.SessionPrintEvent, + "ms": delays[i], + } + } + return result +}