# Source code for nano_llm.agents.voice_chat

#!/usr/bin/env python3
from nano_llm import Agent, Pipeline
from nano_llm.utils import ArgParser, print_table

from nano_llm.plugins import (
    UserPrompt, ChatQuery, PrintStream, 
    AutoASR, AutoTTS, VADFilter, RateLimit,
    ProcessProxy, AudioOutputDevice, AudioRecorder
)


class VoiceChat(Agent):
    """
    Agent for ASR → LLM → TTS pipeline.

    Speech transcripts from the ASR plugin and text from a ``UserPrompt`` are
    both routed into the LLM plugin; the LLM's word output is routed to the TTS
    plugin, whose audio passes through a ``RateLimit`` stage before reaching
    the optional audio output device and/or recorder.
    """

    def __init__(self, asr=None, llm=None, tts=None, **kwargs):
        """
        Args:
          asr (NanoLLM.plugins.AutoASR|str): the ASR plugin instance or model name to connect with the LLM.
          llm (NanoLLM.Plugin|str): The LLM model plugin instance (like ChatQuery) or model name.
          tts (NanoLLM.plugins.AutoTTS|str): the TTS plugin instance (or model name)- if None, will be loaded from kwargs.
        """
        super().__init__(**kwargs)

        # Give every optional attribute a defined value up front, so that
        # reading it on an agent built without ASR or TTS yields None
        # instead of raising AttributeError.
        self.asr_history = None          # most recent partial ASR transcript
        self.tts_output = None           # RateLimit stage placed after the TTS
        self.audio_output_device = None  # AudioOutputDevice plugin, when requested
        self.audio_output_file = None    # AudioRecorder plugin, when requested

        #: The LLM model plugin (like ChatQuery)
        if isinstance(llm, str):
            # a model name is forwarded to ChatQuery through kwargs
            kwargs['model'] = llm

        if not llm or isinstance(llm, str):
            self.llm = ChatQuery(**kwargs)  #ProcessProxy('ChatQuery', **kwargs)
        else:
            self.llm = llm

        self.llm.add(PrintStream(color='green'))

        #: The ASR plugin whose output is connected to the LLM.
        if not asr or isinstance(asr, str):
            self.asr = AutoASR.from_pretrained(asr=asr, **kwargs)
        else:
            self.asr = asr

        # The VAD filter is only created when there is an ASR to feed.
        self.vad = VADFilter(**kwargs).add(self.asr) if self.asr else None

        if self.asr:
            self.asr.add(PrintStream(partial=False, prefix='## ', color='blue'), AutoASR.OutputFinal)
            self.asr.add(PrintStream(partial=False, prefix='>> ', color='magenta'), AutoASR.OutputPartial)

            self.asr.add(self.asr_partial, AutoASR.OutputPartial)  # pause output when user is speaking
            self.asr.add(self.asr_final, AutoASR.OutputFinal)      # clear queues on final ASR transcript
            self.asr.add(self.llm, AutoASR.OutputFinal)            # runs after asr_final() and any interruptions occur

        #: The TTS plugin that speaks the LLM output.
        if not tts or isinstance(tts, str):
            self.tts = AutoTTS.from_pretrained(tts=tts, **kwargs)
        else:
            self.tts = tts

        if self.tts:
            # slow down TTS to realtime and be able to pause it
            self.tts_output = RateLimit(rate=1.0, chunk=9600)
            self.tts.add(self.tts_output)
            self.llm.add(self.tts, ChatQuery.OutputWords)

            # The kwargs entries act as "is this output requested?" flags;
            # when present they are replaced by the constructed plugins.
            self.audio_output_device = kwargs.get('audio_output_device')
            self.audio_output_file = kwargs.get('audio_output_file')

            if self.audio_output_device is not None:
                self.audio_output_device = AudioOutputDevice(**kwargs)
                self.tts_output.add(self.audio_output_device)

            if self.audio_output_file is not None:
                self.audio_output_file = AudioRecorder(**kwargs)
                self.tts_output.add(self.audio_output_file)

        #: Text prompts from web UI or CLI.
        self.prompt = UserPrompt(interactive=True, **kwargs)
        self.prompt.add(self.llm)

        # setup pipeline with two entry nodes
        self.pipeline = [self.prompt]

        if self.vad:
            self.pipeline.append(self.vad)

    def asr_partial(self, text):
        """
        Callback that occurs when the ASR has a partial transcript (while the user is speaking).
        These partial transcripts get revised mid-stream until the user finishes their phrase.
        This is also used for pausing/interrupting the bot output for when the user starts speaking.

        Args:
          text (str): the partial transcript; it is stored in ``asr_history``.
        """
        self.asr_history = text

        # Only react once at least two words have been heard. Splitting on
        # arbitrary whitespace (rather than on a single ' ') keeps a lone word
        # with leading/trailing spaces from being counted as two words.
        if len(text.split()) < 2:
            return

        if self.tts:
            # NOTE(review): 1.0 is presumably a pause duration in seconds - confirm against RateLimit.pause()
            self.tts_output.pause(1.0)

    def asr_final(self, text):
        """
        Callback that occurs when the ASR outputs when there is a pause in the user talking,
        like at the end of a sentence or paragraph.  This will interrupt/cancel any ongoing bot output.

        Args:
          text (str): the final transcript (unused here; the LLM receives it
            through its own connection to the ASR's final output).
        """
        self.asr_history = None
        self.on_interrupt()

    def on_interrupt(self):
        """
        Interrupt/cancel the bot output when the user submits (or speaks) a full query.
        """
        self.llm.interrupt(recursive=False)

        if self.tts:
            self.tts.interrupt(recursive=False)
            self.tts_output.interrupt(block=False, recursive=False)  # might be paused/asleep
if __name__ == "__main__":
    # Command-line entry point: parse the default options plus the ASR, TTS
    # and audio-output groups, then build the agent from them and run it.
    extras = ArgParser.Defaults + ['asr', 'tts', 'audio_output']
    parser = ArgParser(extras=extras)
    args = parser.parse_args()

    # Note that `agent` is bound to whatever run() returns, not necessarily
    # to the VoiceChat instance itself.
    agent = VoiceChat(**vars(args)).run()