diff --git a/backend/rag_solution/router/podcast_router.py b/backend/rag_solution/router/podcast_router.py index 183b9c61..713e440a 100644 --- a/backend/rag_solution/router/podcast_router.py +++ b/backend/rag_solution/router/podcast_router.py @@ -4,14 +4,17 @@ Provides RESTful API for podcast generation, status checking, and management. """ +import io import logging from typing import Annotated from core.config import Settings, get_settings -from fastapi import APIRouter, BackgroundTasks, Depends, Query +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query +from fastapi.responses import StreamingResponse from pydantic import UUID4 from sqlalchemy.ext.asyncio import AsyncSession +from rag_solution.core.dependencies import get_current_user from rag_solution.file_management.database import get_db from rag_solution.schemas.podcast_schema import ( PodcastGenerationInput, @@ -227,3 +230,44 @@ async def delete_podcast( HTTPException 403: Access denied """ await podcast_service.delete_podcast(podcast_id, user_id) + + +# Valid voice IDs for OpenAI TTS voices +VALID_VOICE_IDS = {"alloy", "echo", "fable", "onyx", "nova", "shimmer"} + + +@router.get( + "/voice-preview/{voice_id}", + summary="Get a voice preview", + description="Generates and returns a short audio preview for a given voice ID.", + response_class=StreamingResponse, +) +async def get_voice_preview( + voice_id: str, + podcast_service: Annotated[PodcastService, Depends(get_podcast_service)], + current_user: dict = Depends(get_current_user), +) -> StreamingResponse: + """ + Get a voice preview. + + Args: + voice_id: The ID of the voice to preview. Must be one of: alloy, echo, fable, onyx, nova, shimmer. + podcast_service: Injected podcast service. + current_user: Authenticated user (required for access control). + + Returns: + A streaming response with the audio preview. + + Raises: + HTTPException 400: Invalid voice_id provided. + HTTPException 500: Failed to generate voice preview. + """ + # Validate voice_id + if voice_id not in VALID_VOICE_IDS: + raise HTTPException( + status_code=400, + detail=f"Invalid voice_id '{voice_id}'. Must be one of: {', '.join(sorted(VALID_VOICE_IDS))}", + ) + + audio_bytes = await podcast_service.generate_voice_preview(voice_id) + return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/mpeg") diff --git a/backend/rag_solution/services/podcast_service.py b/backend/rag_solution/services/podcast_service.py index 8376c74e..a8635fe7 100644 --- a/backend/rag_solution/services/podcast_service.py +++ b/backend/rag_solution/services/podcast_service.py @@ -80,6 +80,9 @@ class PodcastService: Generate the complete dialogue script now:""" + # Voice preview text for TTS samples + VOICE_PREVIEW_TEXT = "Hello, you are listening to a preview of this voice." + def __init__( self, session: AsyncSession, @@ -592,3 +595,38 @@ async def delete_podcast(self, podcast_id: UUID4, user_id: UUID4) -> bool: # Delete database record return await self.repository.delete(podcast_id) + + async def generate_voice_preview(self, voice_id: str) -> bytes: + """ + Generate a short audio preview for a specific voice. + + Args: + voice_id: The ID of the voice to preview. + + Returns: + The audio data as bytes. + """ + try: + logger.info("Generating voice preview for voice_id: %s", voice_id) + + # Create audio provider + audio_provider = AudioProviderFactory.create_provider( + provider_type=self.settings.podcast_audio_provider, + settings=self.settings, + ) + + # Generate a short, generic audio preview + audio_bytes = await audio_provider.generate_single_turn_audio( + text=self.VOICE_PREVIEW_TEXT, + voice=voice_id, + audio_format=AudioFormat.MP3, + ) + + return audio_bytes + + except Exception as e: + logger.exception("Failed to generate voice preview for voice_id: %s", voice_id) + raise HTTPException( + status_code=500, + detail=f"Failed to generate voice preview: {e}", + ) from e diff --git a/backend/tests/unit/test_podcast_service_unit.py b/backend/tests/unit/test_podcast_service_unit.py index 2a42b35c..2ea025e1 100644 --- a/backend/tests/unit/test_podcast_service_unit.py +++ b/backend/tests/unit/test_podcast_service_unit.py @@ -191,7 +191,7 @@ def mock_service(self) -> PodcastService: ) @pytest.mark.asyncio - async def test_validate_podcast_input(self, mock_service: PodcastService) -> None: + async def test_validate_podcast_input(self) -> None: """Unit: Validates podcast input schema.""" podcast_input = PodcastGenerationInput( user_id=uuid4(), @@ -310,3 +310,95 @@ async def test_generate_script_uses_generic_topic_without_description( mock_llm_provider.generate_text.assert_called_once() prompt = mock_llm_provider.generate_text.call_args[1]["prompt"] assert "Topic/Focus: General overview of the content" in prompt + + +@pytest.mark.unit +class TestPodcastServiceVoicePreview: + """Unit tests for voice preview functionality.""" + + @pytest.fixture + def mock_service(self) -> PodcastService: + """Fixture: Create mock PodcastService.""" + session = Mock(spec=AsyncSession) + collection_service = Mock(spec=CollectionService) + search_service = Mock(spec=SearchService) + + service = PodcastService( + session=session, + collection_service=collection_service, + search_service=search_service, + ) + + return service + + @pytest.mark.asyncio + async def test_generate_voice_preview_success(self, mock_service: PodcastService) -> None: + """Unit: generate_voice_preview successfully generates audio.""" + voice_id = "alloy" + expected_audio = b"mock_audio_data" + + # Mock AudioProviderFactory + with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory: + mock_provider = AsyncMock() + mock_provider.generate_single_turn_audio = AsyncMock(return_value=expected_audio) + mock_factory.create_provider.return_value = mock_provider + + # Call the method + audio_bytes = await mock_service.generate_voice_preview(voice_id) + + # Assertions + assert audio_bytes == expected_audio + mock_factory.create_provider.assert_called_once() + mock_provider.generate_single_turn_audio.assert_called_once_with( + text=mock_service.VOICE_PREVIEW_TEXT, + voice=voice_id, + audio_format=AudioFormat.MP3, + ) + + @pytest.mark.asyncio + async def test_generate_voice_preview_uses_constant_text(self, mock_service: PodcastService) -> None: + """Unit: generate_voice_preview uses VOICE_PREVIEW_TEXT constant.""" + voice_id = "onyx" + + with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory: + mock_provider = AsyncMock() + mock_provider.generate_single_turn_audio = AsyncMock(return_value=b"audio") + mock_factory.create_provider.return_value = mock_provider + + await mock_service.generate_voice_preview(voice_id) + + # Verify constant is used + call_args = mock_provider.generate_single_turn_audio.call_args + assert call_args.kwargs["text"] == PodcastService.VOICE_PREVIEW_TEXT + + @pytest.mark.asyncio + async def test_generate_voice_preview_raises_on_provider_error(self, mock_service: PodcastService) -> None: + """Unit: generate_voice_preview raises HTTPException on provider error.""" + voice_id = "echo" + + with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory: + mock_provider = AsyncMock() + mock_provider.generate_single_turn_audio = AsyncMock(side_effect=Exception("TTS API error")) + mock_factory.create_provider.return_value = mock_provider + + # Should raise HTTPException + with pytest.raises(Exception) as exc_info: + await mock_service.generate_voice_preview(voice_id) + + # Verify exception is raised + assert exc_info.type.__name__ in ["HTTPException", "Exception"] + + @pytest.mark.asyncio + async def test_generate_voice_preview_all_valid_voices(self, mock_service: PodcastService) -> None: + """Unit: generate_voice_preview works with all valid OpenAI voices.""" + valid_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + + with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory: + mock_provider = AsyncMock() + mock_provider.generate_single_turn_audio = AsyncMock(return_value=b"audio") + mock_factory.create_provider.return_value = mock_provider + + # Test each voice + for voice_id in valid_voices: + audio_bytes = await mock_service.generate_voice_preview(voice_id) + assert audio_bytes == b"audio" diff --git a/frontend/src/components/podcasts/PodcastGenerationModal.tsx b/frontend/src/components/podcasts/PodcastGenerationModal.tsx index 31c5e202..4ebf16a3 100644 --- a/frontend/src/components/podcasts/PodcastGenerationModal.tsx +++ b/frontend/src/components/podcasts/PodcastGenerationModal.tsx @@ -1,7 +1,8 @@ -import React, { useState } from 'react'; +import React, { useState, useRef, useEffect } from 'react'; import { XMarkIcon } from '@heroicons/react/24/outline'; import { useNotification } from '../../contexts/NotificationContext'; -import apiClient, { PodcastGenerationInput } from '../../services/apiClient'; +import apiClient, { PodcastGenerationInput, VoiceId } from '../../services/apiClient'; +import VoiceSelector from './VoiceSelector'; interface PodcastGenerationModalProps { isOpen: boolean; @@ -11,7 +12,7 @@ interface PodcastGenerationModalProps { onPodcastCreated?: (podcastId: string) => void; } -const VOICE_OPTIONS = [ +const VOICE_OPTIONS: Array<{id: VoiceId; name: string; gender: 'male' | 'female' | 'neutral'; description: string}> = [ { id: 'alloy', name: 'Alloy', gender: 'neutral', description: 'Neutral, balanced voice' }, { id: 'echo', name: 'Echo', gender: 'male', description: 'Warm, articulate male voice' }, { id: 'fable', name: 'Fable', gender: 'neutral', description: 'Expressive, storytelling voice' }, @@ -53,6 +54,65 @@ const PodcastGenerationModal: React.FC = ({ const [includeOutro, setIncludeOutro] = useState(false); const [showAdvanced, setShowAdvanced] = useState(false); + const [playingVoiceId, setPlayingVoiceId] = useState(null); + const audioRef = useRef(null); + const audioUrlRef = useRef(null); + + const handlePlayPreview = async (voiceId: VoiceId) => { + if (playingVoiceId === voiceId) { + handleStopPreview(); + return; + } + + try { + const audioBlob = await apiClient.getVoicePreview(voiceId); + const audioUrl = URL.createObjectURL(audioBlob); + + // Clean up previous audio if exists + if (audioRef.current) { + audioRef.current.pause(); + audioRef.current.src = ''; + } + if (audioUrlRef.current) { + URL.revokeObjectURL(audioUrlRef.current); + } + + audioUrlRef.current = audioUrl; + audioRef.current = new Audio(audioUrl); + audioRef.current.play(); + setPlayingVoiceId(voiceId); + + audioRef.current.onended = () => { + setPlayingVoiceId(null); + }; + } catch (error) { + console.error('Error playing voice preview:', error); + const errorMessage = error instanceof Error ? error.message : 'Unknown error'; + addNotification('error', 'Preview Failed', `Could not load voice preview: ${errorMessage}`); + } + }; + + const handleStopPreview = () => { + if (audioRef.current) { + audioRef.current.pause(); + audioRef.current.src = ''; + audioRef.current = null; + } + if (audioUrlRef.current) { + URL.revokeObjectURL(audioUrlRef.current); + audioUrlRef.current = null; + } + setPlayingVoiceId(null); + }; + + useEffect(() => { + return () => { + // Cleanup audio on component unmount + handleStopPreview(); + }; + }, []); + + const selectedDuration = DURATION_OPTIONS.find(d => d.value === duration); const estimatedCost = selectedDuration?.cost || 0; @@ -183,38 +243,24 @@ const PodcastGenerationModal: React.FC = ({ {/* Voice Settings */}
-
- - -
-
- - -
+ +
{/* Advanced Options (Collapsible) */} diff --git a/frontend/src/components/podcasts/VoiceSelector.tsx b/frontend/src/components/podcasts/VoiceSelector.tsx new file mode 100644 index 00000000..e993c905 --- /dev/null +++ b/frontend/src/components/podcasts/VoiceSelector.tsx @@ -0,0 +1,80 @@ +import React from 'react'; +import { PlayIcon, PauseIcon } from '@heroicons/react/24/solid'; +import { VoiceId } from '../../services/apiClient'; + +interface VoiceOption { + id: VoiceId; + name: string; + gender: 'male' | 'female' | 'neutral'; + description: string; +} + +interface VoiceSelectorProps { + label: string; + options: VoiceOption[]; + selectedVoice: string; + onSelectVoice: (voiceId: string) => void; + playingVoiceId: string | null; + onPlayPreview: (voiceId: VoiceId) => void | Promise; + onStopPreview: () => void; +} + +const VoiceSelector: React.FC = ({ + label, + options, + selectedVoice, + onSelectVoice, + playingVoiceId, + onPlayPreview, + onStopPreview, +}) => { + return ( +
+ +
+ {options.map((voice) => { + const isSelected = selectedVoice === voice.id; + const isPlaying = playingVoiceId === voice.id; + + return ( +
onSelectVoice(voice.id)} + className={`flex items-center justify-between p-3 rounded-lg border-2 cursor-pointer transition-all ${ + isSelected + ? 'border-blue-50 bg-blue-50 bg-opacity-20' + : 'border-gray-30 hover:border-gray-40' + }`} + > +
+ +
+
{voice.name}
+
{voice.description}
+
+
+
+ ); + })} +
+
+ ); +}; + +export default VoiceSelector; \ No newline at end of file diff --git a/frontend/src/services/apiClient.ts b/frontend/src/services/apiClient.ts index 5b9484ef..cf1b42f2 100644 --- a/frontend/src/services/apiClient.ts +++ b/frontend/src/services/apiClient.ts @@ -2,6 +2,9 @@ import axios, { AxiosInstance, AxiosResponse } from 'axios'; const API_BASE_URL = process.env.REACT_APP_BACKEND_URL || ''; +// Valid OpenAI TTS voice IDs +type VoiceId = 'alloy' | 'echo' | 'fable' | 'onyx' | 'nova' | 'shimmer'; + interface SearchInput { question: string; collection_id: string; @@ -884,6 +887,16 @@ class ApiClient { ); return response.data; } + + async getVoicePreview(voiceId: VoiceId): Promise { + const response: AxiosResponse = await this.client.get( + `/api/podcasts/voice-preview/${voiceId}`, + { + responseType: 'blob', + } + ); + return response.data; + } } // Create singleton instance @@ -908,4 +921,5 @@ export type { PodcastQuestionInjection, VoiceSettings, PodcastStepDetails, + VoiceId, };