From 18e31fb2358bad4e15bafa64b4e37301f9356eef Mon Sep 17 00:00:00 2001 From: streamer45 Date: Thu, 25 Jul 2024 11:19:10 +0200 Subject: [PATCH] Prevent failure if track context has no samples --- cmd/transcriber/call/tracks.go | 29 +++-- cmd/transcriber/call/transcriber_test.go | 133 ++++++++++++++++++++--- 2 files changed, 140 insertions(+), 22 deletions(-) diff --git a/cmd/transcriber/call/tracks.go b/cmd/transcriber/call/tracks.go index 2b8d604..6872e47 100644 --- a/cmd/transcriber/call/tracks.go +++ b/cmd/transcriber/call/tracks.go @@ -96,6 +96,9 @@ func (t *Transcriber) processLiveTrack(track trackRemote, sessionID string) { ctx.user = user ctx.filename = filepath.Join(getDataDir(), fmt.Sprintf("%s_%s.ogg", user.Id, track.ID())) + var prevArrivalTime time.Time + var prevRTPTimestamp uint32 + slog.Debug("processing voice track", slog.String("username", user.Username), slog.String("sessionID", sessionID), @@ -103,11 +106,18 @@ func (t *Transcriber) processLiveTrack(track trackRemote, sessionID string) { slog.Debug("start reading loop for track", slog.String("trackID", ctx.trackID)) defer func() { slog.Debug("exiting reading loop for track", slog.String("trackID", ctx.trackID)) - select { - case t.trackCtxs <- ctx: - default: - slog.Error("failed to enqueue track context", slog.Any("ctx", ctx)) + + // Only send the track context if we processed at least one audio packet. + if !prevArrivalTime.IsZero() { + select { + case t.trackCtxs <- ctx: + default: + slog.Error("failed to enqueue track context", slog.Any("ctx", ctx)) + } + } else { + slog.Debug("nothing to send", slog.String("trackID", ctx.trackID)) } + t.liveTracksWg.Done() }() @@ -131,8 +141,6 @@ func (t *Transcriber) processLiveTrack(track trackRemote, sessionID string) { } // Read track audio: - var prevArrivalTime time.Time - var prevRTPTimestamp uint32 for { pkt, _, readErr := track.ReadRTP() if readErr != nil { @@ -181,7 +189,7 @@ func (t *Transcriber) processLiveTrack(track trackRemote, sessionID string) { } var gap uint64 - if ctx.startTS == 0 { + if prevArrivalTime.IsZero() { ctx.startTS = time.Since(*t.startTime.Load()).Milliseconds() slog.Debug("start offset for track", slog.Duration("offset", time.Duration(ctx.startTS)*time.Millisecond), @@ -363,6 +371,7 @@ func (ctx trackContext) decodeAudio() ([]trackTimedSamples, error) { if err != nil { slog.Error("failed to decode audio data", slog.String("err", err.Error()), + slog.Any("data", data), slog.String("trackID", ctx.trackID)) } @@ -420,6 +429,12 @@ func (t *Transcriber) transcribeTrack(ctx trackContext) (transcribe.TrackTranscr var speechSamples []trackTimedSamples for _, ts := range samples { + if len(ts.pcm) == 0 { + slog.Warn("unexpected empty audio samples", + slog.String("trackID", ctx.trackID)) + continue + } + // We need to reset the speech detector's state from one chunk of samples // to the next. if err := sd.Reset(); err != nil { diff --git a/cmd/transcriber/call/transcriber_test.go b/cmd/transcriber/call/transcriber_test.go index f83a23b..b42d22f 100644 --- a/cmd/transcriber/call/transcriber_test.go +++ b/cmd/transcriber/call/transcriber_test.go @@ -48,6 +48,16 @@ func setupTranscriberForTest(t *testing.T) *Transcriber { require.NoError(t, err) require.NotNil(t, tr) + dir, err := os.MkdirTemp("", "data") + if err != nil { + require.NoError(t, err) + } + os.Setenv("DATA_DIR", dir) + t.Cleanup(func() { + os.Unsetenv("DATA_DIR") + os.RemoveAll(dir) + }) + return tr } @@ -176,10 +186,6 @@ func TestProcessLiveTrack(t *testing.T) { sessionID := "sessionID" - dataDir := os.Getenv("DATA_DIR") - os.Setenv("DATA_DIR", os.TempDir()) - defer os.Setenv("DATA_DIR", dataDir) - tr.liveTracksWg.Add(1) tr.startTime.Store(newTimeP(time.Now().Add(-time.Second))) tr.processLiveTrack(track, sessionID) @@ -281,10 +287,6 @@ func TestProcessLiveTrack(t *testing.T) { sessionID := "sessionID" - dataDir := os.Getenv("DATA_DIR") - os.Setenv("DATA_DIR", os.TempDir()) - defer os.Setenv("DATA_DIR", dataDir) - tr.liveTracksWg.Add(1) tr.startTime.Store(newTimeP(time.Now().Add(-time.Second))) tr.processLiveTrack(track, sessionID) @@ -385,10 +387,6 @@ func TestProcessLiveTrack(t *testing.T) { sessionID := "sessionID" - dataDir := os.Getenv("DATA_DIR") - os.Setenv("DATA_DIR", os.TempDir()) - defer os.Setenv("DATA_DIR", dataDir) - tr.liveTracksWg.Add(1) tr.startTime.Store(newTimeP(time.Now().Add(-time.Second))) tr.processLiveTrack(track, sessionID) @@ -450,12 +448,117 @@ func TestProcessLiveTrack(t *testing.T) { Body: io.NopCloser(strings.NewReader(`{"id": "userID", "username": "testuser"}`)), }, nil).Once() + track := &trackRemoteMock{ + id: "trackID", + } + + pkts := []*rtp.Packet{ + { + Header: rtp.Header{ + Timestamp: 1000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + { + Header: rtp.Header{ + Timestamp: 2000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + { + Header: rtp.Header{ + Timestamp: 3000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + { + Header: rtp.Header{ + Timestamp: 4000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + } + + var i int + track.readRTP = func() (*rtp.Packet, interceptor.Attributes, error) { + if i >= len(pkts) { + return nil, nil, io.EOF + } + defer func() { i++ }() + return pkts[i], nil, nil + } + tr.liveTracksWg.Add(1) tr.startTime.Store(newTimeP(time.Now().Add(-time.Second))) - tr.processLiveTrack(&trackRemoteMock{ - id: "trackID", - }, "sessionID") + tr.processLiveTrack(track, "sessionID") + close(tr.trackCtxs) require.Len(t, tr.trackCtxs, 1) }) + + t.Run("should not queue contexes with no samples", func(t *testing.T) { + tr := setupTranscriberForTest(t) + + mockClient := &mocks.MockAPIClient{} + tr.apiClient = mockClient + + defer mockClient.AssertExpectations(t) + + mockClient.On("DoAPIRequest", mock.Anything, http.MethodGet, + "http://localhost:8065/plugins/com.mattermost.calls/bot/calls/8w8jorhr7j83uqr6y1st894hqe/sessions/sessionID/profile", "", ""). + Return(&http.Response{ + Body: io.NopCloser(strings.NewReader(`{"id": "userID", "username": "testuser"}`)), + }, nil).Once() + + track := &trackRemoteMock{ + id: "trackID", + } + + pkts := []*rtp.Packet{ + { + Header: rtp.Header{ + Timestamp: 1000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + { + Header: rtp.Header{ + Timestamp: 2000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + { + Header: rtp.Header{ + Timestamp: 3000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + { + Header: rtp.Header{ + Timestamp: 4000, + }, + Payload: []byte{0x45, 0x45, 0x45}, + }, + } + + var i int + track.readRTP = func() (*rtp.Packet, interceptor.Attributes, error) { + if i >= len(pkts) { + return nil, nil, io.EOF + } + + defer func() { i++ }() + + if i == 3 { + time.Sleep(2 * time.Second) + } + + return pkts[i], nil, nil + } + + tr.liveTracksWg.Add(1) + tr.processLiveTrack(track, "sessionID") + close(tr.trackCtxs) + require.Empty(t, tr.trackCtxs) + }) }