Skip to content

Commit

Permalink
Add AV1 support to client implementation (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
streamer45 authored Jul 30, 2024
1 parent 8b783d9 commit e15d468
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 24 deletions.
9 changes: 5 additions & 4 deletions client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (c *Client) StartScreenShare(tracks []webrtc.TrackLocal) (*webrtc.RTPTransc
}
}

c.screenTransceiver = trx
c.screenTransceivers = append(c.screenTransceivers, trx)

sender := trx.Sender()

Expand All @@ -131,13 +131,14 @@ func (c *Client) StopScreenShare() error {
c.mut.Lock()
defer c.mut.Unlock()

if c.screenTransceiver != nil {
if err := c.pc.RemoveTrack(c.screenTransceiver.Sender()); err != nil {
for _, trx := range c.screenTransceivers {
if err := c.pc.RemoveTrack(trx.Sender()); err != nil {
return fmt.Errorf("failed to remove track: %w", err)
}
c.screenTransceiver = nil
}

c.screenTransceivers = nil

return c.sendWS(wsEventScreenOff, nil, false)
}

Expand Down
154 changes: 151 additions & 3 deletions client/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ func TestAPIScreenShare(t *testing.T) {
require.NoError(t, err)

t.Run("not initialized", func(t *testing.T) {
_, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack()})
_, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack(webrtc.MimeTypeVP8)})
require.EqualError(t, err, "rtc client is not initialized")
})

Expand Down Expand Up @@ -533,7 +533,7 @@ func TestAPIScreenShare(t *testing.T) {
// Test logic

// User screen shares, admin should receive the track
userScreenTrack := th.newScreenTrack()
userScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8)
_, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(userScreenTrack, userCloseCh)
Expand Down Expand Up @@ -623,6 +623,154 @@ func TestAPIScreenShare(t *testing.T) {
}
}

func TestAPIScreenShareAV1(t *testing.T) {
th := setupTestHelper(t, "calls0")

th.userClient.cfg.EnableAV1 = true
th.adminClient.cfg.EnableAV1 = true

// Setup
userConnectCh := make(chan struct{})
err := th.userClient.On(RTCConnectEvent, func(_ any) error {
close(userConnectCh)
return nil
})
require.NoError(t, err)

adminConnectCh := make(chan struct{})
err = th.adminClient.On(RTCConnectEvent, func(_ any) error {
close(adminConnectCh)
return nil
})
require.NoError(t, err)

t.Run("not initialized", func(t *testing.T) {
_, err := th.userClient.StartScreenShare([]webrtc.TrackLocal{th.newScreenTrack(webrtc.MimeTypeAV1)})
require.EqualError(t, err, "rtc client is not initialized")
})

go func() {
err := th.userClient.Connect()
require.NoError(t, err)
}()

go func() {
err := th.adminClient.Connect()
require.NoError(t, err)
}()

select {
case <-userConnectCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user connect event")
}

select {
case <-adminConnectCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for admin connect event")
}

userCloseCh := make(chan struct{})
adminCloseCh := make(chan struct{})

// Test logic

// User screen shares, admin should receive the track
userScreenTrack := th.newScreenTrack(webrtc.MimeTypeAV1)
_, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(userScreenTrack, userCloseCh)

screenTrackCh := make(chan struct{})
err = th.adminClient.On(RTCTrackEvent, func(ctx any) error {
m := ctx.(map[string]any)
track := m["track"].(*webrtc.TrackRemote)
if track.Codec().MimeType == webrtc.MimeTypeAV1 {
close(screenTrackCh)
}
return nil
})
require.NoError(t, err)

userScreenOnCh := make(chan struct{})
err = th.adminClient.On(WSCallScreenOnEvent, func(ctx any) error {
sessionID := ctx.(string)
if sessionID == th.userClient.originalConnID {
close(userScreenOnCh)
}
return nil
})
require.NoError(t, err)

select {
case <-userScreenOnCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user screen on event")
}

select {
case <-screenTrackCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for screen track")
}

userScreenOffCh := make(chan struct{})
err = th.adminClient.On(WSCallScreenOffEvent, func(ctx any) error {
sessionID := ctx.(string)
if sessionID == th.userClient.originalConnID {
close(userScreenOffCh)
}
return nil
})
require.NoError(t, err)

err = th.userClient.StopScreenShare()
require.NoError(t, err)

select {
case <-userScreenOffCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for user screen off event")
}

// Teardown

err = th.userClient.On(CloseEvent, func(_ any) error {
close(userCloseCh)
return nil
})
require.NoError(t, err)

err = th.adminClient.On(CloseEvent, func(_ any) error {
close(adminCloseCh)
return nil
})
require.NoError(t, err)

go func() {
err := th.userClient.Close()
require.NoError(t, err)
}()

go func() {
err := th.adminClient.Close()
require.NoError(t, err)
}()

select {
case <-userCloseCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}

select {
case <-adminCloseCh:
case <-time.After(waitTimeout):
require.Fail(t, "timed out waiting for close event")
}
}

func TestAPIConcurrency(t *testing.T) {
t.Run("Mute/Unmute", func(t *testing.T) {
th := setupTestHelper(t, "calls0")
Expand Down Expand Up @@ -746,7 +894,7 @@ func TestAPIScreenShareAndVoice(t *testing.T) {
// Test logic

// User screen shares, admin should receive the track
userScreenTrack := th.newScreenTrack()
userScreenTrack := th.newScreenTrack(webrtc.MimeTypeVP8)
_, err = th.userClient.StartScreenShare([]webrtc.TrackLocal{userScreenTrack})
require.NoError(t, err)
go th.screenTrackWriter(userScreenTrack, userCloseCh)
Expand Down
5 changes: 3 additions & 2 deletions client/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ import (

func (c *Client) joinCall() error {
if err := c.SendWS(wsEventJoin, CallJoinMessage{
ChannelID: c.cfg.ChannelID,
JobID: c.cfg.JobID,
ChannelID: c.cfg.ChannelID,
JobID: c.cfg.JobID,
AV1Support: c.cfg.EnableAV1,
}, false); err != nil {
return fmt.Errorf("failed to send ws msg: %w", err)
}
Expand Down
12 changes: 6 additions & 6 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ type Client struct {
currentConnID string

// WebRTC
pc *webrtc.PeerConnection
dc *webrtc.DataChannel
iceCh chan webrtc.ICECandidateInit
receivers map[string][]*webrtc.RTPReceiver
voiceSender *webrtc.RTPSender
screenTransceiver *webrtc.RTPTransceiver
pc *webrtc.PeerConnection
dc *webrtc.DataChannel
iceCh chan webrtc.ICECandidateInit
receivers map[string][]*webrtc.RTPReceiver
voiceSender *webrtc.RTPSender
screenTransceivers []*webrtc.RTPTransceiver

state int32

Expand Down
3 changes: 3 additions & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ type Config struct {
// JobID is an id used to identify bot initiated sessions (e.g.
// recording/transcription)
JobID string
// EnableAV1 controls whether the client should advertise support
// for receiving the AV1 codec.
EnableAV1 bool

wsURL string
}
Expand Down
22 changes: 15 additions & 7 deletions client/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ const (
waitTimeout = 5 * time.Second
)

func (th *TestHelper) newScreenTrack() *webrtc.TrackLocalStaticRTP {
func (th *TestHelper) newScreenTrack(mimeType string) *webrtc.TrackLocalStaticRTP {
th.tb.Helper()

track, err := webrtc.NewTrackLocalStaticRTP(webrtc.RTPCodecCapability{
MimeType: "video/VP8",
MimeType: mimeType,
ClockRate: 90000,
SDPFmtpLine: "",
RTCPFeedback: []webrtc.RTCPFeedback{
Expand All @@ -68,19 +68,27 @@ func (th *TestHelper) newScreenTrack() *webrtc.TrackLocalStaticRTP {
}

func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, closeCh <-chan struct{}) {
var payloader rtp.Payloader
payloader = &codecs.VP8Payloader{
EnablePictureID: true,
}
filename := "../testfiles/video.ivf"
if track.Codec().MimeType == webrtc.MimeTypeAV1 {
payloader = &codecs.AV1Payloader{}
filename = "../testfiles/video_av1.ivf"
}

packetizer := rtp.NewPacketizer(
1200,
0,
0,
&codecs.VP8Payloader{
EnablePictureID: true,
},
payloader,
rtp.NewRandomSequencer(),
90000,
)

// Open a IVF file and start reading using our IVFReader
file, ivfErr := os.Open("../testfiles/video.ivf")
file, ivfErr := os.Open(filename)
if ivfErr != nil {
log.Fatalf(ivfErr.Error())
}
Expand Down Expand Up @@ -139,7 +147,7 @@ func (th *TestHelper) screenTrackWriter(track *webrtc.TrackLocalStaticRTP, close
func (th *TestHelper) transmitScreenTrack(c *Client) {
th.tb.Helper()

track := th.newScreenTrack()
track := th.newScreenTrack(webrtc.MimeTypeVP8)

sender, err := c.pc.AddTrack(track)
require.NoError(th.tb, err)
Expand Down
5 changes: 3 additions & 2 deletions client/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ package client
const pluginID = "com.mattermost.calls"

type CallJoinMessage struct {
ChannelID string `json:"channelID"`
JobID string `json:"jobID"`
ChannelID string `json:"channelID"`
JobID string `json:"jobID"`
AV1Support bool `json:"av1Support"`
}

type CallReconnectMessage struct {
Expand Down
Binary file added testfiles/video_av1.ivf
Binary file not shown.

0 comments on commit e15d468

Please sign in to comment.