Skip to content

Commit

Permalink
✨ (backend): add management command for video transcription
Browse files Browse the repository at this point in the history
As we want to transcript the whole catalog of videos, we need to add a
management command.
  • Loading branch information
kernicPanel committed Sep 18, 2024
1 parent c9d8dd6 commit 52157df
Show file tree
Hide file tree
Showing 4 changed files with 399 additions and 7 deletions.
54 changes: 54 additions & 0 deletions src/backend/marsha/core/management/commands/transcript_video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""Management command to transcript a video."""

from django.core.management import BaseCommand

from marsha.core import defaults
from marsha.core.models import TimedTextTrack, Video
from marsha.core.utils.transcript_utils import transcript


class Command(BaseCommand):
"""Transcript a video."""

help = "Transcript a video"

def add_arguments(self, parser):
parser.add_argument("--video-id", type=str)

def handle(self, *args, **options):
"""Selects a video to transcript and starts the transcription job."""
video_id = options["video_id"]
if video_id:
try:
video = Video.objects.get(id=video_id)
except Video.DoesNotExist:
self.stdout.write(f"No video matches the provided id: {video_id}")
return

if video.upload_state != defaults.READY:
self.stdout.write(f"Video {video_id} is not ready")
return

if video.timedtexttracks.filter(mode=TimedTextTrack.TRANSCRIPT).exists():
self.stdout.write(f"Transcript already exists for video {video_id}")
return
else:
excluded_timed_text_tracks = TimedTextTrack.objects.filter(
mode=TimedTextTrack.TRANSCRIPT
)
video = (
Video.objects.exclude(timedtexttracks__in=excluded_timed_text_tracks)
.filter(upload_state=defaults.READY)
.order_by("-created_on")
.first()
)
if not video:
self.stdout.write("No video to transcript")
return

try:
self.stdout.write(f"Try to transcript video {video.id}")
transcript(video)
self.stdout.write(f"Transcription job started for video {video.id}")
except Exception as e: # pylint: disable=broad-except
self.stderr.write(f"Error: {e}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
"""Test transcript_video command."""

from io import StringIO
from unittest.mock import patch

from django.core.management import call_command
from django.test import TestCase

from marsha.core import defaults
from marsha.core.factories import TimedTextTrackFactory, VideoFactory
from marsha.core.management.commands import transcript_video
from marsha.core.models import TimedTextTrack


@patch.object(transcript_video, "transcript")
class TranscriptVideoTestCase(TestCase):
"""
Test case for the transcript_video command.
"""

def setUp(self):
"""
Set up the test case with videos.
"""
self.stdout = StringIO()

def test_transcript_video_no_videos(self, mock_transcript):
"""
Should not call the transcript function if there is no video to transcript.
"""
call_command("transcript_video", stdout=self.stdout)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
["No video to transcript"],
)
mock_transcript.assert_not_called()

def test_transcript_video_first_video(self, mock_transcript):
"""
Should call the transcript function with the first video to transcript.
"""
VideoFactory(upload_state=defaults.READY)
video = VideoFactory(upload_state=defaults.READY)

call_command("transcript_video", stdout=self.stdout)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
[
f"Try to transcript video {video.id}",
f"Transcription job started for video {video.id}",
],
)
mock_transcript.assert_called_once_with(video)

def test_transcript_video_not_ready(self, mock_transcript):
"""
Should not call the transcript function if the video is not ready.
"""
VideoFactory(upload_state=defaults.PENDING)

call_command("transcript_video", stdout=self.stdout)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
["No video to transcript"],
)
mock_transcript.assert_not_called()

def test_transcript_video_already_transcript(self, mock_transcript):
"""
Should not call the transcript function if the video already has a transcript.
"""
TimedTextTrackFactory(
video=VideoFactory(upload_state=defaults.READY),
mode=TimedTextTrack.TRANSCRIPT,
)

call_command("transcript_video", stdout=self.stdout)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
["No video to transcript"],
)
mock_transcript.assert_not_called()

def test_transcript_video_deleted_transcript(self, mock_transcript):
"""
Should call the transcript function if the video has a deleted transcript.
"""
timed_text_track = TimedTextTrackFactory(
video=VideoFactory(upload_state=defaults.READY),
mode=TimedTextTrack.TRANSCRIPT,
)
timed_text_track.delete()
self.assertEqual(TimedTextTrack.objects.all(force_visibility=True).count(), 1)

call_command("transcript_video", stdout=self.stdout)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
[
f"Try to transcript video {timed_text_track.video.id}",
f"Transcription job started for video {timed_text_track.video.id}",
],
)
mock_transcript.assert_called_once_with(timed_text_track.video)

def test_transcript_video_unknown_argument(self, mock_transcript):
"""
Should not call the transcript function if there is no video to transcript.
"""
call_command("transcript_video", stdout=self.stdout, video_id=1)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
["No video matches the provided id: 1"],
)
mock_transcript.assert_not_called()

def test_transcript_video_argument(self, mock_transcript):
"""
Should call the transcript function with the video to transcript.
"""
VideoFactory(upload_state=defaults.READY)
video = VideoFactory(upload_state=defaults.READY)
VideoFactory(upload_state=defaults.READY)

call_command("transcript_video", stdout=self.stdout, video_id=video.id)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
[
f"Try to transcript video {video.id}",
f"Transcription job started for video {video.id}",
],
)
mock_transcript.assert_called_once_with(video)

def test_transcript_video_argument_not_ready(self, mock_transcript):
"""
Should not call the transcript function if the video is not ready.
"""
video = VideoFactory(upload_state=defaults.PENDING)

call_command("transcript_video", stdout=self.stdout, video_id=video.id)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
[f"Video {video.id} is not ready"],
)
mock_transcript.assert_not_called()

def test_transcript_video_argument_already_transcript(self, mock_transcript):
"""
Should not call the transcript function if the video already has a transcript.
"""
timed_text_track = TimedTextTrackFactory(
video=VideoFactory(upload_state=defaults.READY),
mode=TimedTextTrack.TRANSCRIPT,
)

call_command(
"transcript_video", stdout=self.stdout, video_id=timed_text_track.video.id
)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
[f"Transcript already exists for video {timed_text_track.video.id}"],
)
mock_transcript.assert_not_called()

def test_transcript_video_argument_deleted_transcript(self, mock_transcript):
"""
Should call the transcript function if the video has a deleted transcript.
"""
timed_text_track = TimedTextTrackFactory(
video=VideoFactory(upload_state=defaults.READY),
mode=TimedTextTrack.TRANSCRIPT,
)
timed_text_track.delete()
self.assertEqual(TimedTextTrack.objects.all(force_visibility=True).count(), 1)

call_command(
"transcript_video", stdout=self.stdout, video_id=timed_text_track.video.id
)

self.assertListEqual(
self.stdout.getvalue().splitlines(),
[
f"Try to transcript video {timed_text_track.video.id}",
f"Transcription job started for video {timed_text_track.video.id}",
],
)
mock_transcript.assert_called_once_with(timed_text_track.video)
110 changes: 103 additions & 7 deletions src/backend/marsha/core/tests/utils/test_transcript.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
"""Tests for the `core.utils.transcript` module."""

from unittest import mock
from unittest.mock import patch

from django.conf import settings
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TestCase
from django.test import TestCase, override_settings

from django_peertube_runner_connector.models import (
Video as TranscriptedVideo,
VideoState,
)

from marsha.core import defaults
from marsha.core.factories import UploadedVideoFactory
from marsha.core.factories import (
TimedTextTrackFactory,
UploadedVideoFactory,
VideoFactory,
)
from marsha.core.models import TimedTextTrack
from marsha.core.storage.storage_class import video_storage
from marsha.core.utils import transcript_utils
from marsha.core.utils.time_utils import to_timestamp
from marsha.core.utils.transcript import transcription_ended_callback
from marsha.websocket.utils import channel_layers_utils


Expand All @@ -40,12 +45,14 @@ def test_transcription_ended_callback(self):
f"{video_path}/{video_timestamp}-{language}.vtt", vtt_file
)

with mock.patch.object(
with patch.object(
channel_layers_utils, "dispatch_timed_text_track"
) as mock_dispatch_timed_text_track, mock.patch.object(
) as mock_dispatch_timed_text_track, patch.object(
channel_layers_utils, "dispatch_video"
) as mock_dispatch_video:
transcription_ended_callback(transcripted_video, language, vtt_path)
transcript_utils.transcription_ended_callback(
transcripted_video, language, vtt_path
)

timed_text_track = video.timedtexttracks.get()
self.assertEqual(timed_text_track.language, language)
Expand All @@ -65,3 +72,92 @@ def test_transcription_ended_callback(self):

mock_dispatch_timed_text_track.assert_called_once_with(timed_text_track)
mock_dispatch_video.assert_called_once_with(video)

@patch.object(transcript_utils, "launch_video_transcript")
def test_transcript_video_no_video(self, mock_launch_video_transcript):
"""
Should not call the launch_video_transcript function
if there is no video to transcript.
"""

with self.assertRaises(transcript_utils.TranscriptError) as context:
transcript_utils.transcript(None)

self.assertEqual(str(context.exception), "No video to transcript")
mock_launch_video_transcript.delay.assert_not_called()

@patch.object(transcript_utils, "launch_video_transcript")
def test_transcript_video_already_transcript(self, mock_launch_video_transcript):
"""
Should not call the launch_video_transcript function
if the video already has a transcript.
"""
timed_text_track = TimedTextTrackFactory(
video=VideoFactory(upload_state=defaults.READY),
language=settings.LANGUAGES[0][0],
mode=TimedTextTrack.TRANSCRIPT,
)

with self.assertRaises(transcript_utils.TranscriptError) as context:
transcript_utils.transcript(timed_text_track.video)

self.assertEqual(
str(context.exception),
f"A transcript already exists for video {timed_text_track.video.id}",
)
mock_launch_video_transcript.delay.assert_not_called()

@patch.object(transcript_utils, "launch_video_transcript")
def test_transcript_video_peertube_pipeline(self, mock_launch_video_transcript):
"""
Should call the launch_video_transcript function
if the video pipeline is peertube.
"""
video = VideoFactory(transcode_pipeline=defaults.PEERTUBE_PIPELINE)

transcript_utils.transcript(video)

mock_launch_video_transcript.delay.assert_called_once_with(
video_pk=video.id,
stamp=video.uploaded_on_stamp(),
domain="https://example.com",
)
self.assertEqual(video.timedtexttracks.count(), 1)

@patch.object(transcript_utils, "launch_video_transcript")
def test_transcript_video_not_peertube_pipeline(self, mock_launch_video_transcript):
"""
Should call the launch_video_transcript function
if the video pipeline is not peertube.
"""
video = VideoFactory(transcode_pipeline=defaults.AWS_PIPELINE)

transcript_utils.transcript(video)

mock_launch_video_transcript.delay.assert_called_once_with(
video_pk=video.id,
stamp=video.uploaded_on_stamp(),
domain="https://example.com",
video_url=f"https://example.com/api/videos/{video.id}/transcript-source/",
)
self.assertEqual(video.timedtexttracks.count(), 1)

@patch.object(transcript_utils, "launch_video_transcript")
@override_settings(TRANSCODING_CALLBACK_DOMAIN="https://callback.com")
def test_transcript_video_callback_domain_setting(
self, mock_launch_video_transcript
):
"""
Should call the launch_video_transcript function
with the callback domain setting.
"""
video = VideoFactory(transcode_pipeline=defaults.PEERTUBE_PIPELINE)

transcript_utils.transcript(video)

mock_launch_video_transcript.delay.assert_called_once_with(
video_pk=video.id,
stamp=video.uploaded_on_stamp(),
domain="https://callback.com",
)
self.assertEqual(video.timedtexttracks.count(), 1)
Loading

0 comments on commit 52157df

Please sign in to comment.