Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion backend/rag_solution/router/podcast_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,17 @@
Provides RESTful API for podcast generation, status checking, and management.
"""

import io
import logging
from typing import Annotated

from core.config import Settings, get_settings
from fastapi import APIRouter, BackgroundTasks, Depends, Query
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query
from fastapi.responses import StreamingResponse
from pydantic import UUID4
from sqlalchemy.ext.asyncio import AsyncSession

from rag_solution.core.dependencies import get_current_user
from rag_solution.file_management.database import get_db
from rag_solution.schemas.podcast_schema import (
PodcastGenerationInput,
Expand Down Expand Up @@ -227,3 +230,44 @@ async def delete_podcast(
HTTPException 403: Access denied
"""
await podcast_service.delete_podcast(podcast_id, user_id)


# Valid voice IDs for OpenAI TTS voices
VALID_VOICE_IDS = {"alloy", "echo", "fable", "onyx", "nova", "shimmer"}


@router.get(
"/voice-preview/{voice_id}",
summary="Get a voice preview",
description="Generates and returns a short audio preview for a given voice ID.",
response_class=StreamingResponse,
)
async def get_voice_preview(
voice_id: str,
podcast_service: Annotated[PodcastService, Depends(get_podcast_service)],
current_user: dict = Depends(get_current_user),
) -> StreamingResponse:
"""
Get a voice preview.

Args:
voice_id: The ID of the voice to preview. Must be one of: alloy, echo, fable, onyx, nova, shimmer.
podcast_service: Injected podcast service.
current_user: Authenticated user (required for access control).

Returns:
A streaming response with the audio preview.

Raises:
HTTPException 400: Invalid voice_id provided.
HTTPException 500: Failed to generate voice preview.
"""
# Validate voice_id
if voice_id not in VALID_VOICE_IDS:
raise HTTPException(
status_code=400,
detail=f"Invalid voice_id '{voice_id}'. Must be one of: {', '.join(sorted(VALID_VOICE_IDS))}",
)

audio_bytes = await podcast_service.generate_voice_preview(voice_id)
return StreamingResponse(io.BytesIO(audio_bytes), media_type="audio/mpeg")
38 changes: 38 additions & 0 deletions backend/rag_solution/services/podcast_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ class PodcastService:

Generate the complete dialogue script now:"""

# Voice preview text for TTS samples
VOICE_PREVIEW_TEXT = "Hello, you are listening to a preview of this voice."

def __init__(
self,
session: AsyncSession,
Expand Down Expand Up @@ -592,3 +595,38 @@ async def delete_podcast(self, podcast_id: UUID4, user_id: UUID4) -> bool:

# Delete database record
return await self.repository.delete(podcast_id)

async def generate_voice_preview(self, voice_id: str) -> bytes:
"""
Generate a short audio preview for a specific voice.

Args:
voice_id: The ID of the voice to preview.

Returns:
The audio data as bytes.
"""
try:
logger.info("Generating voice preview for voice_id: %s", voice_id)

# Create audio provider
audio_provider = AudioProviderFactory.create_provider(
provider_type=self.settings.podcast_audio_provider,
settings=self.settings,
)

# Generate a short, generic audio preview
audio_bytes = await audio_provider.generate_single_turn_audio(
text=self.VOICE_PREVIEW_TEXT,
voice=voice_id,
audio_format=AudioFormat.MP3,
)

return audio_bytes

except Exception as e:
logger.exception("Failed to generate voice preview for voice_id: %s", voice_id)
raise HTTPException(
status_code=500,
detail=f"Failed to generate voice preview: {e}",
) from e
94 changes: 93 additions & 1 deletion backend/tests/unit/test_podcast_service_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def mock_service(self) -> PodcastService:
)

@pytest.mark.asyncio
async def test_validate_podcast_input(self, mock_service: PodcastService) -> None:
async def test_validate_podcast_input(self) -> None:
"""Unit: Validates podcast input schema."""
podcast_input = PodcastGenerationInput(
user_id=uuid4(),
Expand Down Expand Up @@ -310,3 +310,95 @@ async def test_generate_script_uses_generic_topic_without_description(
mock_llm_provider.generate_text.assert_called_once()
prompt = mock_llm_provider.generate_text.call_args[1]["prompt"]
assert "Topic/Focus: General overview of the content" in prompt


@pytest.mark.unit
class TestPodcastServiceVoicePreview:
"""Unit tests for voice preview functionality."""

@pytest.fixture
def mock_service(self) -> PodcastService:
"""Fixture: Create mock PodcastService."""
session = Mock(spec=AsyncSession)
collection_service = Mock(spec=CollectionService)
search_service = Mock(spec=SearchService)

service = PodcastService(
session=session,
collection_service=collection_service,
search_service=search_service,
)

return service

@pytest.mark.asyncio
async def test_generate_voice_preview_success(self, mock_service: PodcastService) -> None:
"""Unit: generate_voice_preview successfully generates audio."""
voice_id = "alloy"
expected_audio = b"mock_audio_data"

# Mock AudioProviderFactory
with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory:
mock_provider = AsyncMock()
mock_provider.generate_single_turn_audio = AsyncMock(return_value=expected_audio)
mock_factory.create_provider.return_value = mock_provider

# Call the method
audio_bytes = await mock_service.generate_voice_preview(voice_id)

# Assertions
assert audio_bytes == expected_audio
mock_factory.create_provider.assert_called_once()
mock_provider.generate_single_turn_audio.assert_called_once_with(
text=mock_service.VOICE_PREVIEW_TEXT,
voice=voice_id,
audio_format=AudioFormat.MP3,
)

@pytest.mark.asyncio
async def test_generate_voice_preview_uses_constant_text(self, mock_service: PodcastService) -> None:
"""Unit: generate_voice_preview uses VOICE_PREVIEW_TEXT constant."""
voice_id = "onyx"

with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory:
mock_provider = AsyncMock()
mock_provider.generate_single_turn_audio = AsyncMock(return_value=b"audio")
mock_factory.create_provider.return_value = mock_provider

await mock_service.generate_voice_preview(voice_id)

# Verify constant is used
call_args = mock_provider.generate_single_turn_audio.call_args
assert call_args.kwargs["text"] == PodcastService.VOICE_PREVIEW_TEXT

@pytest.mark.asyncio
async def test_generate_voice_preview_raises_on_provider_error(self, mock_service: PodcastService) -> None:
"""Unit: generate_voice_preview raises HTTPException on provider error."""
voice_id = "echo"

with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory:
mock_provider = AsyncMock()
mock_provider.generate_single_turn_audio = AsyncMock(side_effect=Exception("TTS API error"))
mock_factory.create_provider.return_value = mock_provider

# Should raise HTTPException
with pytest.raises(Exception) as exc_info:
await mock_service.generate_voice_preview(voice_id)

# Verify exception is raised
assert exc_info.type.__name__ in ["HTTPException", "Exception"]

@pytest.mark.asyncio
async def test_generate_voice_preview_all_valid_voices(self, mock_service: PodcastService) -> None:
"""Unit: generate_voice_preview works with all valid OpenAI voices."""
valid_voices = ["alloy", "echo", "fable", "onyx", "nova", "shimmer"]

with patch("rag_solution.services.podcast_service.AudioProviderFactory") as mock_factory:
mock_provider = AsyncMock()
mock_provider.generate_single_turn_audio = AsyncMock(return_value=b"audio")
mock_factory.create_provider.return_value = mock_provider

# Test each voice
for voice_id in valid_voices:
audio_bytes = await mock_service.generate_voice_preview(voice_id)
assert audio_bytes == b"audio"
116 changes: 81 additions & 35 deletions frontend/src/components/podcasts/PodcastGenerationModal.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import React, { useState } from 'react';
import React, { useState, useRef, useEffect } from 'react';
import { XMarkIcon } from '@heroicons/react/24/outline';
import { useNotification } from '../../contexts/NotificationContext';
import apiClient, { PodcastGenerationInput } from '../../services/apiClient';
import apiClient, { PodcastGenerationInput, VoiceId } from '../../services/apiClient';
import VoiceSelector from './VoiceSelector';

interface PodcastGenerationModalProps {
isOpen: boolean;
Expand All @@ -11,7 +12,7 @@ interface PodcastGenerationModalProps {
onPodcastCreated?: (podcastId: string) => void;
}

const VOICE_OPTIONS = [
const VOICE_OPTIONS: Array<{id: VoiceId; name: string; gender: 'male' | 'female' | 'neutral'; description: string}> = [
{ id: 'alloy', name: 'Alloy', gender: 'neutral', description: 'Neutral, balanced voice' },
{ id: 'echo', name: 'Echo', gender: 'male', description: 'Warm, articulate male voice' },
{ id: 'fable', name: 'Fable', gender: 'neutral', description: 'Expressive, storytelling voice' },
Expand Down Expand Up @@ -53,6 +54,65 @@ const PodcastGenerationModal: React.FC<PodcastGenerationModalProps> = ({
const [includeOutro, setIncludeOutro] = useState(false);
const [showAdvanced, setShowAdvanced] = useState(false);

const [playingVoiceId, setPlayingVoiceId] = useState<string | null>(null);
const audioRef = useRef<HTMLAudioElement | null>(null);
const audioUrlRef = useRef<string | null>(null);

const handlePlayPreview = async (voiceId: VoiceId) => {
if (playingVoiceId === voiceId) {
handleStopPreview();
return;
}

try {
const audioBlob = await apiClient.getVoicePreview(voiceId);
const audioUrl = URL.createObjectURL(audioBlob);

// Clean up previous audio if exists
if (audioRef.current) {
audioRef.current.pause();
audioRef.current.src = '';
}
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
}

audioUrlRef.current = audioUrl;
audioRef.current = new Audio(audioUrl);
audioRef.current.play();
setPlayingVoiceId(voiceId);

audioRef.current.onended = () => {
setPlayingVoiceId(null);
};
} catch (error) {
console.error('Error playing voice preview:', error);
const errorMessage = error instanceof Error ? error.message : 'Unknown error';
addNotification('error', 'Preview Failed', `Could not load voice preview: ${errorMessage}`);
}
};

const handleStopPreview = () => {
if (audioRef.current) {
audioRef.current.pause();
audioRef.current.src = '';
audioRef.current = null;
}
if (audioUrlRef.current) {
URL.revokeObjectURL(audioUrlRef.current);
audioUrlRef.current = null;
}
setPlayingVoiceId(null);
};

useEffect(() => {
return () => {
// Cleanup audio on component unmount
handleStopPreview();
};
}, []);


const selectedDuration = DURATION_OPTIONS.find(d => d.value === duration);
const estimatedCost = selectedDuration?.cost || 0;

Expand Down Expand Up @@ -183,38 +243,24 @@ const PodcastGenerationModal: React.FC<PodcastGenerationModalProps> = ({

{/* Voice Settings */}
<div className="grid grid-cols-2 gap-4">
<div>
<label className="block text-sm font-medium text-white mb-2">
Host Voice
</label>
<select
value={hostVoice}
onChange={(e) => setHostVoice(e.target.value)}
className="w-full px-4 py-2 bg-gray-90 border border-gray-30 rounded-lg text-white focus:outline-none focus:border-blue-50"
>
{VOICE_OPTIONS.map((voice) => (
<option key={voice.id} value={voice.id}>
{voice.name} - {voice.description}
</option>
))}
</select>
</div>
<div>
<label className="block text-sm font-medium text-white mb-2">
Expert Voice
</label>
<select
value={expertVoice}
onChange={(e) => setExpertVoice(e.target.value)}
className="w-full px-4 py-2 bg-gray-90 border border-gray-30 rounded-lg text-white focus:outline-none focus:border-blue-50"
>
{VOICE_OPTIONS.map((voice) => (
<option key={voice.id} value={voice.id}>
{voice.name} - {voice.description}
</option>
))}
</select>
</div>
<VoiceSelector
label="Host Voice"
options={VOICE_OPTIONS}
selectedVoice={hostVoice}
onSelectVoice={setHostVoice}
playingVoiceId={playingVoiceId}
onPlayPreview={handlePlayPreview}
onStopPreview={handleStopPreview}
/>
<VoiceSelector
label="Expert Voice"
options={VOICE_OPTIONS}
selectedVoice={expertVoice}
onSelectVoice={setExpertVoice}
playingVoiceId={playingVoiceId}
onPlayPreview={handlePlayPreview}
onStopPreview={handleStopPreview}
/>
</div>

{/* Advanced Options (Collapsible) */}
Expand Down
Loading
Loading