From 8b42d7051f25bf10f54991eabf60862d3fb052db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Thu, 24 Apr 2025 15:39:21 +0200 Subject: [PATCH 01/10] feat: add configuration dataclasses for asr --- src/rai_asr/rai_asr/agents/initialization.py | 46 ++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/rai_asr/rai_asr/agents/initialization.py diff --git a/src/rai_asr/rai_asr/agents/initialization.py b/src/rai_asr/rai_asr/agents/initialization.py new file mode 100644 index 000000000..970921003 --- /dev/null +++ b/src/rai_asr/rai_asr/agents/initialization.py @@ -0,0 +1,46 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + + +@dataclass +class VADConfig: + model_name: str = "SileroVAD" + threshold: float = 0.5 + sampling_rate: int = 1600 + + +@dataclass +class WWConfig: + model_name: str = "OpenWakeWord" + threshold: float = 0.01 + + +@dataclass +class TranscribeConfig: + model_name: str = "LocalWhisper" + + +@dataclass +class MicrophoneConfig: + device_name: str + + +@dataclass +class ASRAgentConfig: + voice_activity_detection: VADConfig + wakeword: WWConfig + transcribe: TranscribeConfig + microphone: MicrophoneConfig From c8352155724f9d13f2f714962508b721e905ed65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Fri, 25 Apr 2025 15:44:51 +0200 Subject: [PATCH 02/10] feat: add load config, plus additional typing --- src/rai_asr/rai_asr/agents/initialization.py | 37 ++++++++++++++++++-- 1 file changed, 34 insertions(+), 3 deletions(-) diff --git a/src/rai_asr/rai_asr/agents/initialization.py b/src/rai_asr/rai_asr/agents/initialization.py index 970921003..7585261ca 100644 --- a/src/rai_asr/rai_asr/agents/initialization.py +++ b/src/rai_asr/rai_asr/agents/initialization.py @@ -13,24 +13,29 @@ # limitations under the License. 
from dataclasses import dataclass +from typing import Literal, Optional + +import tomli @dataclass class VADConfig: - model_name: str = "SileroVAD" + model_name: Literal["SileroVAD"] = "SileroVAD" threshold: float = 0.5 sampling_rate: int = 1600 @dataclass class WWConfig: - model_name: str = "OpenWakeWord" + model_name: Literal["OpenWakeWord"] = "OpenWakeWord" threshold: float = 0.01 @dataclass class TranscribeConfig: - model_name: str = "LocalWhisper" + model_name: Literal["LocalWhisper", "FasterWhisper", "OpenAIWhisper"] = ( + "LocalWhisper" + ) @dataclass @@ -44,3 +49,29 @@ class ASRAgentConfig: wakeword: WWConfig transcribe: TranscribeConfig microphone: MicrophoneConfig + + +def load_config(config_path: Optional[str] = None) -> ASRAgentConfig: + if config_path is None: + with open("config.toml", "rb") as f: + config_dict = tomli.load(f) + else: + with open(config_path, "rb") as f: + config_dict = tomli.load(f) + return ASRAgentConfig( + voice_activity_detection=VADConfig( + model_name=config_dict["voice_activity_detection"]["model_name"], + threshold=config_dict["voice_activity_detection"]["threshold"], + sampling_rate=config_dict["voice_activity_detection"]["sampling_rate"], + ), + wakeword=WWConfig( + model_name=config_dict["wakeword"]["model_name"], + threshold=config_dict["wakeword"]["threshold"], + ), + transcribe=TranscribeConfig( + model_name=config_dict["transcribe"]["model_name"], + ), + microphone=MicrophoneConfig( + device_name=config_dict["microphone"]["device_name"], + ), + ) From 4b1b75b721c923cb9d0fa9b71022acdedcbdf0c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Fri, 25 Apr 2025 15:51:38 +0200 Subject: [PATCH 03/10] feat: update inits --- src/rai_asr/rai_asr/__init__.py | 21 +++++++++++++++++++++ src/rai_asr/rai_asr/agents/__init__.py | 14 ++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/rai_asr/rai_asr/__init__.py b/src/rai_asr/rai_asr/__init__.py index 499007aa9..60eee367c 100644 --- 
a/src/rai_asr/rai_asr/__init__.py +++ b/src/rai_asr/rai_asr/__init__.py @@ -15,3 +15,24 @@ """RAI ASR package.""" __version__ = "0.1.0" + + +from rai_asr.agents.asr_agent import SpeechRecognitionAgent +from rai_asr.agents.initialization import ( + ASRAgentConfig, + MicrophoneConfig, + TranscribeConfig, + VADConfig, + WWConfig, + load_config, +) + +__all__ = [ + "ASRAgentConfig", + "MicrophoneConfig", + "SpeechRecognitionAgent", + "TranscribeConfig", + "VADConfig", + "WWConfig", + "load_config", +] diff --git a/src/rai_asr/rai_asr/agents/__init__.py b/src/rai_asr/rai_asr/agents/__init__.py index 00a79d27e..2f7c965db 100644 --- a/src/rai_asr/rai_asr/agents/__init__.py +++ b/src/rai_asr/rai_asr/agents/__init__.py @@ -13,7 +13,21 @@ # limitations under the License. from rai_asr.agents.asr_agent import SpeechRecognitionAgent +from rai_asr.agents.initialization import ( + ASRAgentConfig, + MicrophoneConfig, + TranscribeConfig, + VADConfig, + WWConfig, + load_config, +) __all__ = [ + "ASRAgentConfig", + "MicrophoneConfig", "SpeechRecognitionAgent", + "TranscribeConfig", + "VADConfig", + "WWConfig", + "load_config", ] From 87aadf4a6e95f92dd30efed1ff8bd1b47fd33e1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Fri, 25 Apr 2025 16:14:50 +0200 Subject: [PATCH 04/10] feat: finalize cfg structure --- src/rai_asr/rai_asr/agents/initialization.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/rai_asr/rai_asr/agents/initialization.py b/src/rai_asr/rai_asr/agents/initialization.py index 7585261ca..454ab2e91 100644 --- a/src/rai_asr/rai_asr/agents/initialization.py +++ b/src/rai_asr/rai_asr/agents/initialization.py @@ -22,13 +22,14 @@ class VADConfig: model_name: Literal["SileroVAD"] = "SileroVAD" threshold: float = 0.5 - sampling_rate: int = 1600 + silence_grace_period: float = 0.3 @dataclass class WWConfig: model_name: Literal["OpenWakeWord"] = "OpenWakeWord" threshold: float = 0.01 + is_used: bool = False 
@dataclass @@ -36,6 +37,7 @@ class TranscribeConfig: model_name: Literal["LocalWhisper", "FasterWhisper", "OpenAIWhisper"] = ( "LocalWhisper" ) + language: str = "en" @dataclass @@ -60,18 +62,19 @@ def load_config(config_path: Optional[str] = None) -> ASRAgentConfig: config_dict = tomli.load(f) return ASRAgentConfig( voice_activity_detection=VADConfig( - model_name=config_dict["voice_activity_detection"]["model_name"], - threshold=config_dict["voice_activity_detection"]["threshold"], - sampling_rate=config_dict["voice_activity_detection"]["sampling_rate"], + model_name=config_dict["asr"]["vad_model"], + threshold=config_dict["asr"]["vad_threshold"], + silence_grace_period=config_dict["asr"]["silence_grace_period"], ), wakeword=WWConfig( - model_name=config_dict["wakeword"]["model_name"], - threshold=config_dict["wakeword"]["threshold"], + model_name=config_dict["asr"]["wake_word_model"], + threshold=config_dict["asr"]["wake_word_threshold"], + is_used=config_dict["asr"]["use_wake_word"], ), transcribe=TranscribeConfig( - model_name=config_dict["transcribe"]["model_name"], + model_name=config_dict["asr"]["transcription_model"], ), microphone=MicrophoneConfig( - device_name=config_dict["microphone"]["device_name"], + device_name=config_dict["asr"]["recording_device_name"], ), ) From 7838bb240fd55cf1dc17e0e544d766001ab4d397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Mon, 28 Apr 2025 17:06:52 +0200 Subject: [PATCH 05/10] feat: update configurator.py to reflect asr config structure --- config.toml | 7 +- src/rai_asr/rai_asr/__init__.py | 2 + src/rai_asr/rai_asr/agents/__init__.py | 2 + src/rai_asr/rai_asr/agents/initialization.py | 12 +- src/rai_core/rai/frontend/configurator.py | 164 +++++++++++-------- 5 files changed, 113 insertions(+), 74 deletions(-) diff --git a/config.toml b/config.toml index 776e070d6..0fa63023b 100644 --- a/config.toml +++ b/config.toml @@ -13,7 +13,7 @@ region_name = "us-east-1" simple_model = "gpt-4o-mini" 
complex_model = "gpt-4o" embeddings_model = "text-embedding-ada-002" -base_url = "https://api.openai.com/v1/" # for openai compatible apis +base_url = "https://api.openai.com/v1/" [ollama] simple_model = "llama3.2" @@ -34,11 +34,12 @@ host = "https://api.smith.langchain.com" [asr] recording_device_name = "default" -vendor = "whisper" +transcription_model = "LocalWhisper" language = "en" +vad_model = "SileroVAD" silence_grace_period = 0.3 -use_wake_word = false vad_threshold = 0.3 +use_wake_word = false wake_word_model = "" wake_word_threshold = 0.5 diff --git a/src/rai_asr/rai_asr/__init__.py b/src/rai_asr/rai_asr/__init__.py index 60eee367c..911a22087 100644 --- a/src/rai_asr/rai_asr/__init__.py +++ b/src/rai_asr/rai_asr/__init__.py @@ -25,6 +25,7 @@ VADConfig, WWConfig, load_config, + TRANSCRIBE_MODELS, ) __all__ = [ @@ -35,4 +36,5 @@ "VADConfig", "WWConfig", "load_config", + "TRANSCRIBE_MODELS", ] diff --git a/src/rai_asr/rai_asr/agents/__init__.py b/src/rai_asr/rai_asr/agents/__init__.py index 2f7c965db..164c0d819 100644 --- a/src/rai_asr/rai_asr/agents/__init__.py +++ b/src/rai_asr/rai_asr/agents/__init__.py @@ -20,6 +20,7 @@ VADConfig, WWConfig, load_config, + TRANSCRIBE_MODELS, ) __all__ = [ @@ -30,4 +31,5 @@ "VADConfig", "WWConfig", "load_config", + "TRANSCRIBE_MODELS", ] diff --git a/src/rai_asr/rai_asr/agents/initialization.py b/src/rai_asr/rai_asr/agents/initialization.py index 454ab2e91..612d1cfb1 100644 --- a/src/rai_asr/rai_asr/agents/initialization.py +++ b/src/rai_asr/rai_asr/agents/initialization.py @@ -32,13 +32,18 @@ class WWConfig: is_used: bool = False +TRANSCRIBE_MODELS = ["LocalWhisper (Free)", "FasterWhisper (Free)", "OpenAI (Cloud)"] + + @dataclass class TranscribeConfig: - model_name: Literal["LocalWhisper", "FasterWhisper", "OpenAIWhisper"] = ( - "LocalWhisper" - ) + model_name: str = TRANSCRIBE_MODELS[0] language: str = "en" + def __post_init__(self): + if self.model_name not in TRANSCRIBE_MODELS: + raise ValueError(f"model_name must be 
one of {TRANSCRIBE_MODELS}") + @dataclass class MicrophoneConfig: @@ -73,6 +78,7 @@ def load_config(config_path: Optional[str] = None) -> ASRAgentConfig: ), transcribe=TranscribeConfig( model_name=config_dict["asr"]["transcription_model"], + language=config_dict["asr"]["language"], ), microphone=MicrophoneConfig( device_name=config_dict["asr"]["recording_device_name"], diff --git a/src/rai_core/rai/frontend/configurator.py b/src/rai_core/rai/frontend/configurator.py index 5f96d6bba..be7324bd2 100644 --- a/src/rai_core/rai/frontend/configurator.py +++ b/src/rai_core/rai/frontend/configurator.py @@ -27,64 +27,16 @@ from langchain_ollama import ChatOllama, OllamaEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings -# Initialize session state for tracking steps if not exists -if "current_step" not in st.session_state: - st.session_state.current_step = 1 -if "config" not in st.session_state: - # Load initial config from TOML file - try: - with open("config.toml", "rb") as f: - st.session_state.config = tomli.load(f) - except FileNotFoundError: - raise FileNotFoundError("config.toml not found. 
Please recreate it.") - -# Sidebar progress tracker -st.sidebar.title("Configuration Progress") -steps = { - 1: "👋 Welcome", - 2: "🤖 Model Selection", - 3: "📊 Tracing", - 4: "🎙️ Speech Recognition", - 5: "🔊 Text to Speech", - 6: "🎯 Additional Features", - 7: "✅ Review & Save", -} -# Replace the existing step display with clickable elements -for step_num, step_name in steps.items(): - if step_num == st.session_state.current_step: - # Current step is bold and has an arrow - if st.sidebar.button( - step_name, key=f"step_{step_num}", use_container_width=True - ): - st.session_state.current_step = step_num - else: - # Other steps are clickable but not highlighted - if st.sidebar.button( - step_name, key=f"step_{step_num}", use_container_width=True - ): - st.session_state.current_step = step_num - - -# Navigation buttons -def next_step(): - st.session_state.current_step = st.session_state.current_step + 1 - - -def prev_step(): - st.session_state.current_step = st.session_state.current_step - 1 - - -# Main content based on current step -if st.session_state.current_step == 1: +def welcome(): st.title("Welcome to RAI Configurator! 👋") st.markdown( """ This wizard will help you set up your RAI environment step by step: 1. Configure your AI models and vendor 2. Set up model tracing and monitoring - 3. Configure speech recognition (ASR) - 4. Set up text-to-speech (TTS) + 3. Configure speech recognition (ASR) (if installed) + 4. Set up text-to-speech (TTS) (if installed) 5. Enable additional features 6. 
Review and save your configuration @@ -94,7 +46,8 @@ def prev_step(): st.button("Begin Configuration →", on_click=next_step) -elif st.session_state.current_step == 2: + +def model_selection(): st.title("Model Configuration") st.info( """ @@ -302,7 +255,8 @@ def on_model_vendor_change(model_type: str): with col2: st.button("Next →", on_click=next_step) -elif st.session_state.current_step == 3: + +def tracing(): st.title("Tracing Configuration") st.info( """ @@ -394,7 +348,9 @@ def on_langsmith_change(): with col2: st.button("Next →", on_click=next_step) -elif st.session_state.current_step == 4: + +def asr(): + from rai_asr import TRANSCRIBE_MODELS def on_recording_device_change(): st.session_state.config["asr"]["recording_device_name"] = ( @@ -407,7 +363,7 @@ def on_asr_vendor_change(): if st.session_state.asr_vendor_select == "Local Whisper (Free)" else "openai" ) - st.session_state.config["asr"]["vendor"] = vendor + st.session_state.config["asr"]["transcription_model"] = vendor def on_language_change(): st.session_state.config["asr"]["language"] = st.session_state.language_input @@ -485,30 +441,29 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in recording_devices = get_recording_devices(reinitialize=True) # Get the current vendor from config and convert to display name - current_vendor = st.session_state.config.get("asr", {}).get("vendor", "whisper") - vendor_display_name = ( - "Local Whisper (Free)" if current_vendor == "whisper" else "OpenAI (Cloud)" + current_vendor = st.session_state.config.get("asr", {}).get( + "transciption_model", TRANSCRIBE_MODELS[0] ) asr_vendor = st.selectbox( "Choose your ASR vendor", - ["Local Whisper (Free)", "OpenAI (Cloud)"], + TRANSCRIBE_MODELS, placeholder="Select vendor", - index=["Local Whisper (Free)", "OpenAI (Cloud)"].index(vendor_display_name), + index=TRANSCRIBE_MODELS.index(current_vendor), key="asr_vendor_select", on_change=on_asr_vendor_change, ) - if asr_vendor == "Local Whisper 
(Free)": + if asr_vendor == "OpenAI (Cloud)": st.info( """ - Local Whisper is recommended to use when Nvidia GPU is available. + OpenAI ASR uses the OpenAI API. Make sure to set `OPENAI_API_KEY` environment variable. """ ) - elif asr_vendor == "OpenAI (Cloud)": + else: st.info( - """ - OpenAI ASR uses the OpenAI API. Make sure to set `OPENAI_API_KEY` environment variable. + f""" + {asr_vendor} is recommended to use when Nvidia GPU is available. """ ) @@ -577,8 +532,8 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in with col2: st.button("Next →", on_click=next_step) -elif st.session_state.current_step == 5: +def tts(): def on_tts_vendor_change(): vendor = ( "elevenlabs" @@ -663,7 +618,8 @@ def on_keep_speaker_busy_change(): with col2: st.button("Next →", on_click=next_step) -elif st.session_state.current_step == 6: + +def additional_features(): st.title("Additional Features Configuration") st.info( """ @@ -718,7 +674,8 @@ def on_keep_speaker_busy_change(): with col2: st.button("Next →", on_click=next_step) -elif st.session_state.current_step == 7: + +def review_and_save(): st.title("Review & Save Configuration") st.write( """ @@ -913,3 +870,74 @@ def test_recording_device(index: int, sample_rate: int): with open("config.toml", "wb") as f: tomli_w.dump(st.session_state.config, f) st.success("Configuration saved successfully!") + + +@st.cache_data +def setup_steps(): + step_names = ["👋 Welcome", "🤖 Model Selection", "📊 Tracing"] + step_render = [welcome, model_selection, tracing] + + try: + from rai_asr import TRANSCRIBE_MODELS + + step_names.append("🎙️ Speech Recognition") + step_render.append(asr) + except ImportError: + pass + + step_names.extend( + [ + "🔊 Text to Speech", + "🎯 Additional Features", + "✅ Review & Save", + ] + ) + step_render.extend([tts, additional_features, review_and_save]) + + steps = dict(enumerate(step_names)) + step_renderer = dict(enumerate(step_render)) + return steps, step_renderer + + +# Initialize 
session state for tracking steps if not exists +if "current_step" not in st.session_state: + st.session_state.current_step = 1 +if "config" not in st.session_state: + # Load initial config from TOML file + try: + with open("config.toml", "rb") as f: + st.session_state.config = tomli.load(f) + except FileNotFoundError: + raise FileNotFoundError("config.toml not found. Please recreate it.") + +# Sidebar progress tracker +st.sidebar.title("Configuration Progress") +steps, step_renderer = setup_steps() + +# Replace the existing step display with clickable elements +for step_num, step_name in steps.items(): + if step_num == st.session_state.current_step: + # Current step is bold and has an arrow + if st.sidebar.button( + step_name, key=f"step_{step_num}", use_container_width=True + ): + st.session_state.current_step = step_num + else: + # Other steps are clickable but not highlighted + if st.sidebar.button( + step_name, key=f"step_{step_num}", use_container_width=True + ): + st.session_state.current_step = step_num + + +# Navigation buttons +def next_step(): + st.session_state.current_step = st.session_state.current_step + 1 + + +def prev_step(): + st.session_state.current_step = st.session_state.current_step - 1 + + +# Main content based on current step +step_renderer[st.session_state.current_step]() From 6a2ceff51a6563c761ed3c476de0adbc2c508231 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Mon, 28 Apr 2025 17:40:13 +0200 Subject: [PATCH 06/10] feat: add from config initialization to ASR agent --- src/rai_asr/rai_asr/agents/asr_agent.py | 49 +++++++++++++++++++++++++ src/rai_asr/rai_asr/models/__init__.py | 3 +- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/src/rai_asr/rai_asr/agents/asr_agent.py b/src/rai_asr/rai_asr/agents/asr_agent.py index 44c518a39..6c7ce574a 100644 --- a/src/rai_asr/rai_asr/agents/asr_agent.py +++ b/src/rai_asr/rai_asr/agents/asr_agent.py @@ -17,6 +17,7 @@ import time from threading import Event, Lock, 
Thread from typing import Any, List, Optional, TypedDict +from typing_extensions import Self from uuid import uuid4 import numpy as np @@ -35,6 +36,7 @@ ) from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel +from .initialization import ASRAgentConfig, load_config class ThreadData(TypedDict): @@ -109,6 +111,53 @@ def __init__( self.transcription_buffers: dict[str, list[NDArray]] = {} self.is_playing = True + @classmethod + def from_config(cls, cfg_path: Optional[str] = None) -> Self: + cfg = load_config(cfg_path) + microphone_configuration = SoundDeviceConfig( + stream=True, + channels=1, + device_name=cfg.microphone.device_name, + block_size=1280, + consumer_sampling_rate=16000, + dtype="int16", + device_number=None, + is_input=True, + is_output=False, + ) + match cfg.transcribe.model_name: + case "LocalWhisper (Free)": + from rai_asr.models import LocalWhisper + + model = LocalWhisper("tiny", 16000, language=cfg.transcribe.language) + case "FasterWhisper (Free)": + from rai_asr.models import FasterWhisper + + model = FasterWhisper("tiny", 16000, language=cfg.transcribe.language) + case "OpenAI (Cloud)": + from rai_asr.models import OpenAIWhisper + + model = OpenAIWhisper("tiny", 16000, language=cfg.transcribe.language) + case _: + raise ValueError(f"Unknown model name f{cfg.transcribe.model_name}") + + match cfg.voice_activity_detection.model_name: + case "SileroVAD": + from rai_asr.models import SileroVAD + + vad = SileroVAD(16000, cfg.voice_activity_detection.threshold) + + agent = cls(microphone_configuration, "rai_auto_asr_agent", model, vad) + if cfg.wakeword.is_used: + match cfg.wakeword.model_name: + case "OpenWakeWord": + from rai_asr.models import OpenWakeWord + + agent.add_detection_model( + OpenWakeWord("hey jarvis", cfg.wakeword.threshold) + ) + return agent + def __call__(self): self.run() diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py index daaa95a15..84f9bccd7 100644 --- 
a/src/rai_asr/rai_asr/models/__init__.py +++ b/src/rai_asr/rai_asr/models/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel -from rai_asr.models.local_whisper import LocalWhisper +from rai_asr.models.local_whisper import LocalWhisper, FasterWhisper from rai_asr.models.open_ai_whisper import OpenAIWhisper from rai_asr.models.open_wake_word import OpenWakeWord from rai_asr.models.silero_vad import SileroVAD @@ -22,6 +22,7 @@ "BaseTranscriptionModel", "BaseVoiceDetectionModel", "LocalWhisper", + "FasterWhisper", "OpenAIWhisper", "OpenWakeWord", "SileroVAD", From a28fc6ed3bbf8da287b4940528b963f7cc48d733 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Tue, 29 Apr 2025 12:37:31 +0200 Subject: [PATCH 07/10] feat: add specific model selection --- config.toml | 2 + src/rai_asr/rai_asr/agents/asr_agent.py | 18 ++++++--- src/rai_asr/rai_asr/agents/initialization.py | 20 ++++++---- src/rai_core/rai/frontend/configurator.py | 40 ++++++++++++++++---- 4 files changed, 60 insertions(+), 20 deletions(-) diff --git a/config.toml b/config.toml index 0fa63023b..d2932f9ca 100644 --- a/config.toml +++ b/config.toml @@ -42,6 +42,8 @@ vad_threshold = 0.3 use_wake_word = false wake_word_model = "" wake_word_threshold = 0.5 +wake_word_model_name = "" +transcription_model_name = "tiny" [tts] vendor = "elevenlabs" diff --git a/src/rai_asr/rai_asr/agents/asr_agent.py b/src/rai_asr/rai_asr/agents/asr_agent.py index 6c7ce574a..79ec336d5 100644 --- a/src/rai_asr/rai_asr/agents/asr_agent.py +++ b/src/rai_asr/rai_asr/agents/asr_agent.py @@ -125,19 +125,25 @@ def from_config(cls, cfg_path: Optional[str] = None) -> Self: is_input=True, is_output=False, ) - match cfg.transcribe.model_name: + match cfg.transcribe.model_type: case "LocalWhisper (Free)": from rai_asr.models import LocalWhisper - model = LocalWhisper("tiny", 16000, language=cfg.transcribe.language) + model = LocalWhisper( + 
cfg.transcribe.model_name, 16000, language=cfg.transcribe.language + ) case "FasterWhisper (Free)": from rai_asr.models import FasterWhisper - model = FasterWhisper("tiny", 16000, language=cfg.transcribe.language) + model = FasterWhisper( + cfg.transcribe.model_name, 16000, language=cfg.transcribe.language + ) case "OpenAI (Cloud)": from rai_asr.models import OpenAIWhisper - model = OpenAIWhisper("tiny", 16000, language=cfg.transcribe.language) + model = OpenAIWhisper( + cfg.transcribe.model_name, 16000, language=cfg.transcribe.language + ) case _: raise ValueError(f"Unknown model name f{cfg.transcribe.model_name}") @@ -149,12 +155,12 @@ def from_config(cls, cfg_path: Optional[str] = None) -> Self: agent = cls(microphone_configuration, "rai_auto_asr_agent", model, vad) if cfg.wakeword.is_used: - match cfg.wakeword.model_name: + match cfg.wakeword.model_type: case "OpenWakeWord": from rai_asr.models import OpenWakeWord agent.add_detection_model( - OpenWakeWord("hey jarvis", cfg.wakeword.threshold) + OpenWakeWord(cfg.wakeword.model_name, cfg.wakeword.threshold) ) return agent diff --git a/src/rai_asr/rai_asr/agents/initialization.py b/src/rai_asr/rai_asr/agents/initialization.py index 612d1cfb1..9ad163db0 100644 --- a/src/rai_asr/rai_asr/agents/initialization.py +++ b/src/rai_asr/rai_asr/agents/initialization.py @@ -27,22 +27,26 @@ class VADConfig: @dataclass class WWConfig: - model_name: Literal["OpenWakeWord"] = "OpenWakeWord" + model_name: str = "hey jarvis" + model_type: Literal["OpenWakeWord"] = "OpenWakeWord" threshold: float = 0.01 is_used: bool = False -TRANSCRIBE_MODELS = ["LocalWhisper (Free)", "FasterWhisper (Free)", "OpenAI (Cloud)"] +TRANSCRIBE_MODELS = ["LocalWhisper", "FasterWhisper", "OpenAI"] @dataclass class TranscribeConfig: - model_name: str = TRANSCRIBE_MODELS[0] + model_type: str = TRANSCRIBE_MODELS[0] + model_name: str = "tiny" language: str = "en" def __post_init__(self): - if self.model_name not in TRANSCRIBE_MODELS: - raise 
ValueError(f"model_name must be one of {TRANSCRIBE_MODELS}") + if self.model_type not in TRANSCRIBE_MODELS: + raise ValueError( + f"unknown model_type: {self.model_type}. Must be one of {TRANSCRIBE_MODELS}" + ) @dataclass @@ -72,12 +76,14 @@ def load_config(config_path: Optional[str] = None) -> ASRAgentConfig: silence_grace_period=config_dict["asr"]["silence_grace_period"], ), wakeword=WWConfig( - model_name=config_dict["asr"]["wake_word_model"], + model_type=config_dict["asr"]["wake_word_model"], + model_name=config_dict["asr"]["wake_word_model_name"], threshold=config_dict["asr"]["wake_word_threshold"], is_used=config_dict["asr"]["use_wake_word"], ), transcribe=TranscribeConfig( - model_name=config_dict["asr"]["transcription_model"], + model_type=config_dict["asr"]["transcription_model"], + model_name=config_dict["asr"]["transcription_model_name"], language=config_dict["asr"]["language"], ), microphone=MicrophoneConfig( diff --git a/src/rai_core/rai/frontend/configurator.py b/src/rai_core/rai/frontend/configurator.py index be7324bd2..64959f816 100644 --- a/src/rai_core/rai/frontend/configurator.py +++ b/src/rai_core/rai/frontend/configurator.py @@ -358,12 +358,14 @@ def on_recording_device_change(): ) def on_asr_vendor_change(): - vendor = ( - "whisper" - if st.session_state.asr_vendor_select == "Local Whisper (Free)" - else "openai" + st.session_state.config["asr"]["transcription_model"] = ( + st.session_state.asr_vendor_select + ) + + def on_model_name_change(): + st.session_state.config["asr"]["transcription_model_name"] = ( + st.session_state.model_name_input ) - st.session_state.config["asr"]["transcription_model"] = vendor def on_language_change(): st.session_state.config["asr"]["language"] = st.session_state.language_input @@ -388,6 +390,11 @@ def on_wake_word_model_change(): st.session_state.wake_word_model_input ) + def on_wake_word_model_name_change(): + st.session_state.config["asr"]["wake_word_model_name"] = ( + 
st.session_state.wake_word_model_name_input + ) + def on_wake_word_threshold_change(): st.session_state.config["asr"]["wake_word_threshold"] = ( st.session_state.wake_word_threshold_input @@ -454,7 +461,7 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in on_change=on_asr_vendor_change, ) - if asr_vendor == "OpenAI (Cloud)": + if asr_vendor == "OpenAI": st.info( """ OpenAI ASR uses the OpenAI API. Make sure to set `OPENAI_API_KEY` environment variable. @@ -470,6 +477,14 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in # Add ASR parameters st.subheader("ASR Parameters") + model_name = st.text_input( + "Model name", + value=st.session_state.config.get("asr", {}).get("transcription_model_name", "tiny"), + help="Particular model architecture of the provided type, e.g. 'tiny'", + key="model_name_input", + on_change=on_model_name_change, + ) + language = st.text_input( "Language code", value=st.session_state.config.get("asr", {}).get("language", "en"), @@ -510,10 +525,21 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in wake_word_model = st.text_input( "Wake word model", value=st.session_state.config.get("asr", {}).get("wake_word_model", ""), - help="Wake word model to use", + help="Wake word model type to use", key="wake_word_model_input", on_change=on_wake_word_model_change, ) + + wake_word_model_name = st.text_input( + "Wake word model name", + value=st.session_state.config.get("asr", {}).get( + "wake_word_model_name", "" + ), + help="Specific wake word model to use", + key="wake_word_model_name_input", + on_change=on_wake_word_model_name_change, + ) + wake_word_threshold = st.slider( "Wake word threshold", min_value=0.0, From 4eeceba61c8601369232acdaa6634d6532352fa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Tue, 29 Apr 2025 15:59:32 +0200 Subject: [PATCH 08/10] feat: configuration capabilities to tts --- config.toml | 5 +-
src/rai_core/rai/frontend/configurator.py | 111 +++++++++++++------ src/rai_tts/rai_tts/__init__.py | 13 ++- src/rai_tts/rai_tts/agents/__init__.py | 5 +- src/rai_tts/rai_tts/agents/initialization.py | 53 +++++++++ 5 files changed, 144 insertions(+), 43 deletions(-) create mode 100644 src/rai_tts/rai_tts/agents/initialization.py diff --git a/config.toml b/config.toml index d2932f9ca..b92db97fe 100644 --- a/config.toml +++ b/config.toml @@ -46,5 +46,6 @@ wake_word_model_name = "" transcription_model_name = "tiny" [tts] -vendor = "elevenlabs" -keep_speaker_busy = false +vendor = "ElevenLabs" +voice = "" +speaker_device_name = "default" diff --git a/src/rai_core/rai/frontend/configurator.py b/src/rai_core/rai/frontend/configurator.py index 64959f816..e6d483f9a 100644 --- a/src/rai_core/rai/frontend/configurator.py +++ b/src/rai_core/rai/frontend/configurator.py @@ -27,6 +27,26 @@ from langchain_ollama import ChatOllama, OllamaEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from rai.communication import sound_device + + +def get_sound_devices( + reinitialize: bool = False, output: bool = False +) -> List[Dict[str, str | int]]: + if reinitialize: + sd._terminate() + sd._initialize() + devices: List[Dict[str, str | int]] = sd.query_devices() + if output: + recording_devices = [ + device for device in devices if device.get("max_output_channels", 0) > 0 + ] + else: + recording_devices = [ + device for device in devices if device.get("max_input_channels", 0) > 0 + ] + return recording_devices + def welcome(): st.title("Welcome to RAI Configurator! 
👋") @@ -413,17 +433,7 @@ def on_wake_word_threshold_change(): """ ) - def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | int]]: - if reinitialize: - sd._terminate() - sd._initialize() - devices: List[Dict[str, str | int]] = sd.query_devices() - recording_devices = [ - device for device in devices if device.get("max_input_channels", 0) > 0 - ] - return recording_devices - - recording_devices = get_recording_devices() + recording_devices = get_sound_devices() currently_selected_device_name = st.session_state.config.get("asr", {}).get( "recording_device_name", "" ) @@ -445,7 +455,7 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in refresh_devices = st.button("Refresh devices") if refresh_devices: - recording_devices = get_recording_devices(reinitialize=True) + recording_devices = get_sound_devices(reinitialize=True) # Get the current vendor from config and convert to display name current_vendor = st.session_state.config.get("asr", {}).get( @@ -560,17 +570,17 @@ def get_recording_devices(reinitialize: bool = False) -> List[Dict[str, str | in def tts(): + from rai_tts import TTS_MODELS + def on_tts_vendor_change(): - vendor = ( - "elevenlabs" - if st.session_state.tts_vendor_select == "ElevenLabs (Cloud)" - else "opentts" - ) - st.session_state.config["tts"]["vendor"] = vendor + st.session_state.config["tts"]["vendor"] = st.session_state.tts_vendor_select - def on_keep_speaker_busy_change(): - st.session_state.config["tts"]["keep_speaker_busy"] = ( - st.session_state.keep_speaker_busy_checkbox + def on_voice_change(): + st.session_state.config["tts"]["voice"] = st.session_state.tts_voice_input + + def on_sound_device_change(): + st.session_state.config["tts"]["speaker_device_name"] = ( + st.session_state.sound_device_select ) # Ensure tts config exists @@ -586,22 +596,43 @@ def on_keep_speaker_busy_change(): """ ) - # Get the current vendor from config and convert to display name - current_vendor = 
st.session_state.config.get("tts", {}).get("vendor", "elevenlabs") - vendor_display_name = ( - "ElevenLabs (Cloud)" if current_vendor == "elevenlabs" else "OpenTTS (Local)" + sound_devices = get_sound_devices(output=True) + currently_selected_device_name = st.session_state.config.get("tts", {}).get( + "speaker_device_name", "" ) + try: + device_index = [device["name"] for device in sound_devices].index( + currently_selected_device_name + ) + except ValueError: + device_index = None + + recording_device_name = st.selectbox( + "Default speaker device", + [device["name"] for device in sound_devices], + placeholder="Select device", + index=device_index, + key="sound_device_select", + on_change=on_sound_device_change, + ) + + refresh_devices = st.button("Refresh devices") + if refresh_devices: + recording_devices = get_sound_devices(reinitialize=True, output=True) + + # Get the current vendor from config and convert to display name + current_vendor = st.session_state.config.get("tts", {}).get("vendor", TTS_MODELS[0]) tts_vendor = st.selectbox( "Choose your TTS vendor", - ["ElevenLabs (Cloud)", "OpenTTS (Local)"], + TTS_MODELS, placeholder="Select vendor", - index=["ElevenLabs (Cloud)", "OpenTTS (Local)"].index(vendor_display_name), + index=TTS_MODELS.index(current_vendor), key="tts_vendor_select", on_change=on_tts_vendor_change, ) - if tts_vendor == "ElevenLabs (Cloud)": + if tts_vendor == "ElevenLabs": st.info( """ Please ensure you have the following environment variable set: @@ -612,7 +643,7 @@ def on_keep_speaker_busy_change(): To get your API key, follow the instructions [here](https://elevenlabs.io/docs/api-reference/getting-started) """ ) - elif tts_vendor == "OpenTTS (Local)": + elif tts_vendor == "OpenTTS": st.info( """ Please ensure you have the Docker container running: @@ -624,11 +655,12 @@ def on_keep_speaker_busy_change(): """ ) - keep_speaker_busy = st.checkbox( - "Keep speaker busy", - value=st.session_state.config.get("tts", {}).get("keep_speaker_busy", 
False), - key="keep_speaker_busy_checkbox", - on_change=on_keep_speaker_busy_change, + model_name = st.text_input( + "Voice", + value=st.session_state.config.get("asr", {}).get("voice", ""), + help="Voice compatible with selected vendor. If left empty RAI will select a deafault value.", + key="tts_voice_input", + on_change=on_voice_change, ) st.info( @@ -911,14 +943,21 @@ def setup_steps(): except ImportError: pass + try: + from rai_tts import TTS_MODELS + + step_names.append("🔊 Text to Speech") + step_render.append(tts) + except ImportError as e: + pass + step_names.extend( [ - "🔊 Text to Speech", "🎯 Additional Features", "✅ Review & Save", ] ) - step_render.extend([tts, additional_features, review_and_save]) + step_render.extend([additional_features, review_and_save]) steps = dict(enumerate(step_names)) step_renderer = dict(enumerate(step_render)) diff --git a/src/rai_tts/rai_tts/__init__.py b/src/rai_tts/rai_tts/__init__.py index aa6023dfe..3ddb1bdde 100644 --- a/src/rai_tts/rai_tts/__init__.py +++ b/src/rai_tts/rai_tts/__init__.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .agents import TextToSpeechAgent +from .agents import TextToSpeechAgent, TTS_MODELS, load_config, TTSAgentConfig from .models import ElevenLabsTTS, OpenTTS -__all__ = ["ElevenLabsTTS", "OpenTTS", "TextToSpeechAgent"] + +__all__ = [ + "ElevenLabsTTS", + "OpenTTS", + "TextToSpeechAgent", + "TextToSpeechAgent", + "TTSAgentConfig", + "load_config", + "TTS_MODELS", +] diff --git a/src/rai_tts/rai_tts/agents/__init__.py b/src/rai_tts/rai_tts/agents/__init__.py index 3d2b0ebc0..33ac95fc2 100644 --- a/src/rai_tts/rai_tts/agents/__init__.py +++ b/src/rai_tts/rai_tts/agents/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. 
from rai_tts.agents.tts_agent import TextToSpeechAgent +from rai_tts.agents.initialization import TTS_MODELS, load_config, TTSAgentConfig -__all__ = [ - "TextToSpeechAgent", -] +__all__ = ["TextToSpeechAgent", "TTSAgentConfig", "load_config", "TTS_MODELS"] diff --git a/src/rai_tts/rai_tts/agents/initialization.py b/src/rai_tts/rai_tts/agents/initialization.py new file mode 100644 index 000000000..376ba7dda --- /dev/null +++ b/src/rai_tts/rai_tts/agents/initialization.py @@ -0,0 +1,53 @@ +# Copyright (C) 2025 Robotec.AI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import Literal, Optional + +import tomli + + +@dataclass +class SpeakerConfig: + device_name: str + + +TTS_MODELS = ["OpenTTS", "ElevenLabs"] + + +@dataclass +class TTSConfig: + model_type: str = TTS_MODELS[0] + voice: str = "" + + +@dataclass +class TTSAgentConfig: + text_to_speech: TTSConfig + speaker: SpeakerConfig + + +def load_config(config_path: Optional[str] = None) -> TTSAgentConfig: + if config_path is None: + with open("config.toml", "rb") as f: + config_dict = tomli.load(f) + else: + with open(config_path, "rb") as f: + config_dict = tomli.load(f) + return TTSAgentConfig( + text_to_speech=TTSConfig( + model_type=config_dict["tts"]["vendor"], voice=config_dict["tts"]["voice"] + ), + speaker=SpeakerConfig(device_name=config_dict["tts"]["speaker_device_name"]), + ) From c40715c88a6f2e3cb28460cdfb846f58cb61a573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Tue, 29 Apr 2025 16:49:33 +0200 Subject: [PATCH 09/10] feat: tts agent from config method --- src/rai_tts/rai_tts/agents/tts_agent.py | 30 +++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/rai_tts/rai_tts/agents/tts_agent.py b/src/rai_tts/rai_tts/agents/tts_agent.py index e62e2d124..02f6a68c0 100644 --- a/src/rai_tts/rai_tts/agents/tts_agent.py +++ b/src/rai_tts/rai_tts/agents/tts_agent.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from threading import Event, Thread from typing import TYPE_CHECKING, Optional +from typing_extensions import Self from uuid import uuid4 from numpy._typing import NDArray @@ -35,6 +36,8 @@ from rai_interfaces.msg._hri_message import HRIMessage from rai_tts.models.base import TTSModel +from .initialization import load_config + if TYPE_CHECKING: from rai.communication.sound_device.api import SoundDeviceConfig @@ -119,6 +122,33 @@ def __init__( self.playback_data = PlayData() + @classmethod + def from_config(cls, cfg_path: Optional[str] = None) -> Self: + cfg = 
load_config(cfg_path) + config = SoundDeviceConfig( + stream=True, + is_output=True, + device_name=cfg.speaker.device_name, + ) + match cfg.text_to_speech.model_type: + case "ElevenLabs": + from rai_tts.models import ElevenLabsTTS + + if cfg.text_to_speech.voice != "": + model = ElevenLabsTTS(voice=cfg.text_to_speech.voice) + else: + raise ValueError("ElevenLabs [tts] vendor required voice to be set") + case "OpenTTS": + from rai_tts.models import OpenTTS + + if cfg.text_to_speech.voice != "": + model = OpenTTS(voice=cfg.text_to_speech.voice) + else: + model = OpenTTS() + case _: + raise ValueError(f"Unknown model_type: {cfg.text_to_speech.model_type}") + return cls(config, "rai_auto_tts", model) + def __call__(self): self.run() From 40f11736c9283620b4103bb12e3f9126fc51bb0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Tue, 29 Apr 2025 17:20:40 +0200 Subject: [PATCH 10/10] chore: pre-commit --- .pre-commit-config.yaml | 1 + src/rai_asr/rai_asr/__init__.py | 4 ++-- src/rai_asr/rai_asr/agents/__init__.py | 4 ++-- src/rai_asr/rai_asr/agents/asr_agent.py | 5 +++-- src/rai_asr/rai_asr/models/__init__.py | 4 ++-- src/rai_core/rai/frontend/configurator.py | 4 +--- src/rai_tts/rai_tts/__init__.py | 7 +++---- src/rai_tts/rai_tts/agents/__init__.py | 4 ++-- src/rai_tts/rai_tts/agents/initialization.py | 2 +- src/rai_tts/rai_tts/agents/tts_agent.py | 2 +- 10 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ca94af87b..08b8759d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,4 +28,5 @@ repos: hooks: - id: ruff args: [--extend-select, "I,RUF022", --fix, --ignore, E731] + exclude: src/rai_core/rai/frontend/configurator.py - id: ruff-format diff --git a/src/rai_asr/rai_asr/__init__.py b/src/rai_asr/rai_asr/__init__.py index 911a22087..d0e8691ae 100644 --- a/src/rai_asr/rai_asr/__init__.py +++ b/src/rai_asr/rai_asr/__init__.py @@ -19,16 +19,17 @@ from 
rai_asr.agents.asr_agent import SpeechRecognitionAgent from rai_asr.agents.initialization import ( + TRANSCRIBE_MODELS, ASRAgentConfig, MicrophoneConfig, TranscribeConfig, VADConfig, WWConfig, load_config, - TRANSCRIBE_MODELS, ) __all__ = [ + "TRANSCRIBE_MODELS", "ASRAgentConfig", "MicrophoneConfig", "SpeechRecognitionAgent", @@ -36,5 +37,4 @@ "VADConfig", "WWConfig", "load_config", - "TRANSCRIBE_MODELS", ] diff --git a/src/rai_asr/rai_asr/agents/__init__.py b/src/rai_asr/rai_asr/agents/__init__.py index 164c0d819..27abe7179 100644 --- a/src/rai_asr/rai_asr/agents/__init__.py +++ b/src/rai_asr/rai_asr/agents/__init__.py @@ -14,16 +14,17 @@ from rai_asr.agents.asr_agent import SpeechRecognitionAgent from rai_asr.agents.initialization import ( + TRANSCRIBE_MODELS, ASRAgentConfig, MicrophoneConfig, TranscribeConfig, VADConfig, WWConfig, load_config, - TRANSCRIBE_MODELS, ) __all__ = [ + "TRANSCRIBE_MODELS", "ASRAgentConfig", "MicrophoneConfig", "SpeechRecognitionAgent", @@ -31,5 +32,4 @@ "VADConfig", "WWConfig", "load_config", - "TRANSCRIBE_MODELS", ] diff --git a/src/rai_asr/rai_asr/agents/asr_agent.py b/src/rai_asr/rai_asr/agents/asr_agent.py index 79ec336d5..98390996a 100644 --- a/src/rai_asr/rai_asr/agents/asr_agent.py +++ b/src/rai_asr/rai_asr/agents/asr_agent.py @@ -17,7 +17,6 @@ import time from threading import Event, Lock, Thread from typing import Any, List, Optional, TypedDict -from typing_extensions import Self from uuid import uuid4 import numpy as np @@ -34,9 +33,11 @@ SoundDeviceConnector, SoundDeviceMessage, ) +from typing_extensions import Self from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel -from .initialization import ASRAgentConfig, load_config + +from .initialization import load_config class ThreadData(TypedDict): diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py index 84f9bccd7..e0187559f 100644 --- a/src/rai_asr/rai_asr/models/__init__.py +++ 
b/src/rai_asr/rai_asr/models/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel -from rai_asr.models.local_whisper import LocalWhisper, FasterWhisper +from rai_asr.models.local_whisper import FasterWhisper, LocalWhisper from rai_asr.models.open_ai_whisper import OpenAIWhisper from rai_asr.models.open_wake_word import OpenWakeWord from rai_asr.models.silero_vad import SileroVAD @@ -21,8 +21,8 @@ __all__ = [ "BaseTranscriptionModel", "BaseVoiceDetectionModel", - "LocalWhisper", "FasterWhisper", + "LocalWhisper", "OpenAIWhisper", "OpenWakeWord", "SileroVAD", diff --git a/src/rai_core/rai/frontend/configurator.py b/src/rai_core/rai/frontend/configurator.py index e6d483f9a..d5d71ea5c 100644 --- a/src/rai_core/rai/frontend/configurator.py +++ b/src/rai_core/rai/frontend/configurator.py @@ -27,8 +27,6 @@ from langchain_ollama import ChatOllama, OllamaEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings -from rai.communication import sound_device - def get_sound_devices( reinitialize: bool = False, output: bool = False @@ -948,7 +946,7 @@ def setup_steps(): step_names.append("🔊 Text to Speech") step_render.append(tts) - except ImportError as e: + except ImportError: pass step_names.extend( diff --git a/src/rai_tts/rai_tts/__init__.py b/src/rai_tts/rai_tts/__init__.py index 3ddb1bdde..2581f13a6 100644 --- a/src/rai_tts/rai_tts/__init__.py +++ b/src/rai_tts/rai_tts/__init__.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .agents import TextToSpeechAgent, TTS_MODELS, load_config, TTSAgentConfig +from .agents import TTS_MODELS, TextToSpeechAgent, TTSAgentConfig, load_config from .models import ElevenLabsTTS, OpenTTS - __all__ = [ + "TTS_MODELS", "ElevenLabsTTS", "OpenTTS", + "TTSAgentConfig", "TextToSpeechAgent", "TextToSpeechAgent", - "TTSAgentConfig", "load_config", - "TTS_MODELS", ] diff --git a/src/rai_tts/rai_tts/agents/__init__.py b/src/rai_tts/rai_tts/agents/__init__.py index 33ac95fc2..6805fb87b 100644 --- a/src/rai_tts/rai_tts/agents/__init__.py +++ b/src/rai_tts/rai_tts/agents/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from rai_tts.agents.initialization import TTS_MODELS, TTSAgentConfig, load_config from rai_tts.agents.tts_agent import TextToSpeechAgent -from rai_tts.agents.initialization import TTS_MODELS, load_config, TTSAgentConfig -__all__ = ["TextToSpeechAgent", "TTSAgentConfig", "load_config", "TTS_MODELS"] +__all__ = ["TTS_MODELS", "TTSAgentConfig", "TextToSpeechAgent", "load_config"] diff --git a/src/rai_tts/rai_tts/agents/initialization.py b/src/rai_tts/rai_tts/agents/initialization.py index 376ba7dda..870e75b95 100644 --- a/src/rai_tts/rai_tts/agents/initialization.py +++ b/src/rai_tts/rai_tts/agents/initialization.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Literal, Optional +from typing import Optional import tomli diff --git a/src/rai_tts/rai_tts/agents/tts_agent.py b/src/rai_tts/rai_tts/agents/tts_agent.py index 02f6a68c0..99f8fd4a6 100644 --- a/src/rai_tts/rai_tts/agents/tts_agent.py +++ b/src/rai_tts/rai_tts/agents/tts_agent.py @@ -17,7 +17,6 @@ from dataclasses import dataclass from threading import Event, Thread from typing import TYPE_CHECKING, Optional -from typing_extensions import Self from uuid import uuid4 from numpy._typing import NDArray @@ -32,6 +31,7 @@ from rai.communication.sound_device import SoundDeviceConfig, SoundDeviceConnector from rai.communication.sound_device.connector import SoundDeviceMessage from std_msgs.msg import String +from typing_extensions import Self from rai_interfaces.msg._hri_message import HRIMessage from rai_tts.models.base import TTSModel