Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(rai_tts&rai_hmi): stable elevenlabs voice, reintroduced chunking #247

Merged
merged 4 commits into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/rai_asr/rai_asr/asr_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
from scipy.signal import resample
from std_msgs.msg import String

Expand Down Expand Up @@ -208,8 +209,14 @@ def _initialize_parameters(self):
self.get_logger().info("Parameters have been initialized") # type: ignore

def _setup_publishers_and_subscribers(self):

self.transcription_publisher = self.create_publisher(String, "/from_human", 10)
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_ALL,
)
self.transcription_publisher = self.create_publisher(
String, "/from_human", qos_profile=reliable_qos
)
self.status_publisher = self.create_publisher(String, "/asr_status", 10)
self.tts_status_subscriber = self.create_subscription(
String,
Expand Down
23 changes: 19 additions & 4 deletions src/rai_hmi/rai_hmi/voice_hmi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#

import logging
import re
import threading
import time
from queue import Queue
Expand All @@ -23,6 +24,7 @@
from langchain_core.messages import HumanMessage
from rclpy.callback_groups import ReentrantCallbackGroup
from rclpy.executors import MultiThreadedExecutor
from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
from std_msgs.msg import String

from rai.node import RaiBaseNode
Expand Down Expand Up @@ -86,15 +88,22 @@ def __init__(
super().__init__(node_name, queue, robot_description_package)

self.callback_group = ReentrantCallbackGroup()
reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_ALL,
)
self.hmi_subscription = self.create_subscription(
String,
"from_human",
self.handle_human_message,
10,
qos_profile=reliable_qos,
)

self.hmi_publisher = self.create_publisher(
String, "to_human", 10, callback_group=self.callback_group
String,
"to_human",
qos_profile=reliable_qos,
callback_group=self.callback_group,
)
self.history = []

Expand All @@ -103,6 +112,12 @@ def __init__(
def set_agent(self, agent):
    """Attach an agent to this node.

    The given object is stored unchanged as ``self.agent``; nothing is
    validated and nothing is returned.
    """
    self.agent = agent

def split_and_publish(self, message: str):
    """Publish ``message`` through ``self.hmi_publisher`` in smaller chunks.

    The text is split on a whitespace character that directly follows a
    period, and on every ``:`` or ``!`` (those two delimiters are consumed
    by the split). Each resulting chunk is stripped of surrounding
    whitespace and published as its own ``String`` message, in order.

    Args:
        message: Full text to split and publish.
    """
    chunks = re.split(r"(?<=\.)\s|[:!]", message)
    for chunk in chunks:
        # Splitting on ":" / "!" leaves the following space attached to the
        # next chunk, and trailing whitespace yields chunks such as " ".
        # Strip first so blank chunks are dropped instead of being published.
        text = chunk.strip()
        if text:
            self.hmi_publisher.publish(String(data=text))

def handle_human_message(self, msg: String):
self.processing = True
self.get_logger().info("Processing started")
Expand All @@ -118,7 +133,7 @@ def handle_human_message(self, msg: String):
self.get_logger().info(
f'Sending message to human: "{last_message}"'
)
self.hmi_publisher.publish(String(data=last_message))
self.split_and_publish(last_message)

self.get_logger().info("Processing finished")
self.processing = False
Expand Down
27 changes: 23 additions & 4 deletions src/rai_tts/rai_tts/tts_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,19 @@
import logging
import os
import tempfile
import time
from abc import abstractmethod
from typing import Optional

import requests
from elevenlabs.client import ElevenLabs
from elevenlabs.types import Voice
from elevenlabs.types.voice_settings import VoiceSettings

logger = logging.getLogger(__name__)

TTS_TRIES = 2
TTS_TRIES = 5
TTS_RETRY_DELAY = 0.5


class TTSClient:
Expand All @@ -46,10 +50,19 @@ def save_audio_to_file(audio_data: bytes, suffix: str) -> str:
class ElevenLabsClient(TTSClient):
def __init__(self, voice: str, base_url: Optional[str] = None):
self.base_url = base_url
self.voice = voice
api_key = os.getenv(key="ELEVENLABS_API_KEY")
self.client = ElevenLabs(base_url=None, api_key=api_key)

self.voice_settings = VoiceSettings(
stability=0.7,
similarity_boost=0.5,
)
voices = self.client.voices.get_all().voices
voice_id = next((v.voice_id for v in voices if v.name == voice), None)
if voice_id is None:
raise ValueError(f"Voice {voice} not found")
self.voice = Voice(voice_id=voice_id, settings=self.voice_settings)

def synthesize_speech_to_file(self, text: str) -> str:
tries = 0
while tries < TTS_TRIES:
Expand All @@ -62,9 +75,15 @@ def synthesize_speech_to_file(self, text: str) -> str:
audio_data = b"".join(response)
return self.save_audio_to_file(audio_data, suffix=".mp3")
except Exception as e:
logger.warn(f"Error occurred during sythesizing speech: {e}.") # type: ignore
logger.warn(f"Error occurred during synthesizing speech: {e}.") # type: ignore
tries += 1
audio_data = b"".join(response)
if tries == TTS_TRIES:
logger.error(
f"Failed to synthesize speech after {TTS_TRIES} tries. Creating empty audio file instead."
)
time.sleep(TTS_RETRY_DELAY)

audio_data = b""
return self.save_audio_to_file(audio_data, suffix=".mp3")


Expand Down
11 changes: 8 additions & 3 deletions src/rai_tts/rai_tts/tts_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import rclpy
from rclpy.node import Node
from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy
from std_msgs.msg import String

from .tts_clients import ElevenLabsClient, OpenTTSClient, TTSClient
Expand All @@ -42,9 +43,13 @@ def __init__(self):
self.declare_parameter("topic", "to_human")

topic_param = self.get_parameter("topic").get_parameter_value().string_value # type: ignore

reliable_qos = QoSProfile(
reliability=ReliabilityPolicy.RELIABLE,
durability=DurabilityPolicy.TRANSIENT_LOCAL,
history=HistoryPolicy.KEEP_ALL,
)
self.subscription = self.create_subscription( # type: ignore
String, topic_param, self.listener_callback, 10 # type: ignore
String, topic_param, self.listener_callback, qos_profile=reliable_qos # type: ignore
)
self.playing = False
self.status_publisher = self.create_publisher(String, "tts_status", 10) # type: ignore
Expand Down Expand Up @@ -73,6 +78,7 @@ def listener_callback(self, msg: String):
self.get_logger().info( # type: ignore
f"Registering new TTS job: {self.job_id} length: {len(msg.data)} chars." # type: ignore
)
self.get_logger().debug(f"The job: {msg.data}") # type: ignore

threading.Thread(
target=self.start_synthesize_thread, args=(msg, self.job_id) # type: ignore
Expand Down Expand Up @@ -130,7 +136,6 @@ def _play_audio(self, filepath: str):
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
self.get_logger().debug(f"Playing audio: {filepath}") # type: ignore
self.playing = False

def _initialize_client(self) -> TTSClient:
Expand Down