diff --git a/src/rai_asr/rai_asr/asr_node.py b/src/rai_asr/rai_asr/asr_node.py index 4354c290..7212c344 100755 --- a/src/rai_asr/rai_asr/asr_node.py +++ b/src/rai_asr/rai_asr/asr_node.py @@ -28,6 +28,7 @@ from rclpy.callback_groups import ReentrantCallbackGroup from rclpy.executors import SingleThreadedExecutor from rclpy.node import Node +from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy from scipy.signal import resample from std_msgs.msg import String @@ -208,8 +209,14 @@ def _initialize_parameters(self): self.get_logger().info("Parameters have been initialized") # type: ignore def _setup_publishers_and_subscribers(self): - - self.transcription_publisher = self.create_publisher(String, "/from_human", 10) + reliable_qos = QoSProfile( + reliability=ReliabilityPolicy.RELIABLE, + durability=DurabilityPolicy.TRANSIENT_LOCAL, + history=HistoryPolicy.KEEP_ALL, + ) + self.transcription_publisher = self.create_publisher( + String, "/from_human", qos_profile=reliable_qos + ) self.status_publisher = self.create_publisher(String, "/asr_status", 10) self.tts_status_subscriber = self.create_subscription( String, diff --git a/src/rai_hmi/rai_hmi/voice_hmi.py b/src/rai_hmi/rai_hmi/voice_hmi.py index 09eb1bb5..caf70bdf 100644 --- a/src/rai_hmi/rai_hmi/voice_hmi.py +++ b/src/rai_hmi/rai_hmi/voice_hmi.py @@ -14,6 +14,7 @@ # import logging +import re import threading import time from queue import Queue @@ -23,6 +24,7 @@ from langchain_core.messages import HumanMessage from rclpy.callback_groups import ReentrantCallbackGroup from rclpy.executors import MultiThreadedExecutor +from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy from std_msgs.msg import String from rai.node import RaiBaseNode @@ -86,15 +88,22 @@ def __init__( super().__init__(node_name, queue, robot_description_package) self.callback_group = ReentrantCallbackGroup() + reliable_qos = QoSProfile( + reliability=ReliabilityPolicy.RELIABLE, + durability=DurabilityPolicy.TRANSIENT_LOCAL, + history=HistoryPolicy.KEEP_ALL, + ) self.hmi_subscription = self.create_subscription( String, "from_human", self.handle_human_message, - 10, + qos_profile=reliable_qos, ) - self.hmi_publisher = self.create_publisher( - String, "to_human", 10, callback_group=self.callback_group + String, + "to_human", + qos_profile=reliable_qos, + callback_group=self.callback_group, ) self.history = [] @@ -103,6 +112,12 @@ def __init__( def set_agent(self, agent): self.agent = agent + def split_and_publish(self, message: str): + sentences = re.split(r"(?<=\.)\s|[:!]", message) + for sentence in sentences: + if sentence: + self.hmi_publisher.publish(String(data=sentence)) + def handle_human_message(self, msg: String): self.processing = True self.get_logger().info("Processing started") @@ -118,7 +133,7 @@ def handle_human_message(self, msg: String): self.get_logger().info( f'Sending message to human: "{last_message}"' ) - self.hmi_publisher.publish(String(data=last_message)) + self.split_and_publish(last_message) self.get_logger().info("Processing finished") self.processing = False diff --git a/src/rai_tts/rai_tts/tts_clients.py b/src/rai_tts/rai_tts/tts_clients.py index 21974df7..ea2367bc 100644 --- a/src/rai_tts/rai_tts/tts_clients.py +++ b/src/rai_tts/rai_tts/tts_clients.py @@ -16,15 +16,19 @@ import logging import os import tempfile +import time from abc import abstractmethod from typing import Optional import requests from elevenlabs.client import ElevenLabs +from elevenlabs.types import Voice +from elevenlabs.types.voice_settings import VoiceSettings logger = logging.getLogger(__name__) -TTS_TRIES = 2 +TTS_TRIES = 5 +TTS_RETRY_DELAY = 0.5 class TTSClient: @@ -46,10 +50,19 @@ def save_audio_to_file(audio_data: bytes, suffix: str) -> str: class ElevenLabsClient(TTSClient): def __init__(self, voice: str, base_url: Optional[str] = None): self.base_url = base_url - self.voice = voice api_key = os.getenv(key="ELEVENLABS_API_KEY") self.client = ElevenLabs(base_url=None, api_key=api_key) + self.voice_settings = VoiceSettings( + stability=0.7, + similarity_boost=0.5, + ) + voices = self.client.voices.get_all().voices + voice_id = next((v.voice_id for v in voices if v.name == voice), None) + if voice_id is None: + raise ValueError(f"Voice {voice} not found") + self.voice = Voice(voice_id=voice_id, settings=self.voice_settings) + def synthesize_speech_to_file(self, text: str) -> str: tries = 0 while tries < TTS_TRIES: @@ -62,9 +75,15 @@ def synthesize_speech_to_file(self, text: str) -> str: audio_data = b"".join(response) return self.save_audio_to_file(audio_data, suffix=".mp3") except Exception as e: - logger.warn(f"Error occurred during sythesizing speech: {e}.") # type: ignore + logger.warn(f"Error occurred during synthesizing speech: {e}.") # type: ignore tries += 1 - audio_data = b"".join(response) + if tries == TTS_TRIES: + logger.error( + f"Failed to synthesize speech after {TTS_TRIES} tries. Creating empty audio file instead." + ) + time.sleep(TTS_RETRY_DELAY) + + audio_data = b"" return self.save_audio_to_file(audio_data, suffix=".mp3") diff --git a/src/rai_tts/rai_tts/tts_node.py b/src/rai_tts/rai_tts/tts_node.py index c2901ae1..e4946832 100644 --- a/src/rai_tts/rai_tts/tts_node.py +++ b/src/rai_tts/rai_tts/tts_node.py @@ -22,6 +22,7 @@ import rclpy from rclpy.node import Node +from rclpy.qos import DurabilityPolicy, HistoryPolicy, QoSProfile, ReliabilityPolicy from std_msgs.msg import String from .tts_clients import ElevenLabsClient, OpenTTSClient, TTSClient @@ -42,9 +43,13 @@ def __init__(self): self.declare_parameter("topic", "to_human") topic_param = self.get_parameter("topic").get_parameter_value().string_value # type: ignore - + reliable_qos = QoSProfile( + reliability=ReliabilityPolicy.RELIABLE, + durability=DurabilityPolicy.TRANSIENT_LOCAL, + history=HistoryPolicy.KEEP_ALL, + ) self.subscription = self.create_subscription( # type: ignore - String, topic_param, self.listener_callback, 10 # type: ignore + String, topic_param, self.listener_callback, qos_profile=reliable_qos # type: ignore ) self.playing = False self.status_publisher = self.create_publisher(String, "tts_status", 10) # type: ignore @@ -73,6 +78,7 @@ def listener_callback(self, msg: String): self.get_logger().info( # type: ignore f"Registering new TTS job: {self.job_id} length: {len(msg.data)} chars." # type: ignore ) + self.get_logger().debug(f"The job: {msg.data}") # type: ignore threading.Thread( target=self.start_synthesize_thread, args=(msg, self.job_id) # type: ignore @@ -130,7 +136,6 @@ def _play_audio(self, filepath: str): stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) - self.get_logger().debug(f"Playing audio: {filepath}") # type: ignore self.playing = False def _initialize_client(self) -> TTSClient: