#!/usr/bin/env python3
from nano_llm import Agent, Pipeline
from nano_llm.utils import ArgParser, print_table
from nano_llm.plugins import (
    UserPrompt, ChatQuery, PrintStream, 
    AutoASR, AutoTTS, VADFilter, RateLimit,
    ProcessProxy, AudioOutputDevice, AudioRecorder
)
class VoiceChat(Agent):
    """
    Agent for ASR → LLM → TTS pipeline.

    The pipeline has two entry nodes:
      * a VADFilter whose output feeds the ASR; final ASR transcripts are
        forwarded to the LLM.
      * a UserPrompt for text prompts from the web UI or CLI, also forwarded
        to the LLM.

    The LLM's word-level output is sent to the TTS. The TTS output goes
    through a RateLimit node, which slows it to realtime so it can be paused
    or interrupted when the user starts speaking. From there it goes to an
    optional audio output device and/or recorder.
    """

    def __init__(self, asr=None, llm=None, tts=None, **kwargs):
        """
        Args:
          asr (NanoLLM.plugins.AutoASR|str): the ASR plugin instance or model name to connect with the LLM.
            If a name (or None), ``AutoASR.from_pretrained(asr=asr, **kwargs)`` is used.
            A falsy result disables the ASR/VAD path.
          llm (NanoLLM.Plugin|str): The LLM model plugin instance (like ChatQuery) or model name.
            If a name (or None), a ChatQuery plugin is created from kwargs.
          tts (NanoLLM.plugins.AutoTTS|str): the TTS plugin instance (or model name) - if None, will be loaded from kwargs.
            A falsy result disables speech output.
          kwargs: forwarded to Agent and to every plugin constructed here (ChatQuery, AutoASR,
            VADFilter, AutoTTS, AudioOutputDevice, AudioRecorder, UserPrompt).
            ``audio_output_device`` / ``audio_output_file`` enable the matching audio sink
            when they are not None. These only take effect when TTS is enabled.
        """
        super().__init__(**kwargs)

        # Initialize state unconditionally so these attributes exist even when
        # the ASR or TTS path is disabled.
        self.asr_history = None          # latest partial ASR transcript (None when idle)
        self.tts_output = None           # RateLimit node after the TTS (set when TTS is enabled)
        self.audio_output_device = None  # AudioOutputDevice sink, if requested
        self.audio_output_file = None    # AudioRecorder sink, if requested

        #: The LLM model plugin (like ChatQuery)
        if isinstance(llm, str):
            # NOTE: kwargs is shared, so 'model' is also passed to every plugin created below.
            kwargs['model'] = llm

        if not llm or isinstance(llm, str):
            # NOTE: could alternatively run out-of-process via ProcessProxy('ChatQuery', **kwargs)
            self.llm = ChatQuery(**kwargs)
        else:
            self.llm = llm

        self.llm.add(PrintStream(color='green'))

        #: The ASR plugin whose output is connected to the LLM.
        if not asr or isinstance(asr, str):
            self.asr = AutoASR.from_pretrained(asr=asr, **kwargs)
        else:
            self.asr = asr

        # The VAD filter feeds audio into the ASR; it only exists when ASR is enabled.
        self.vad = VADFilter(**kwargs).add(self.asr) if self.asr else None

        if self.asr:
            self.asr.add(PrintStream(partial=False, prefix='## ', color='blue'), AutoASR.OutputFinal)
            self.asr.add(PrintStream(partial=False, prefix='>> ', color='magenta'), AutoASR.OutputPartial)

            self.asr.add(self.asr_partial, AutoASR.OutputPartial)  # pause output when user is speaking
            self.asr.add(self.asr_final, AutoASR.OutputFinal)      # clear queues on final ASR transcript
            self.asr.add(self.llm, AutoASR.OutputFinal)  # runs after asr_final() and any interruptions occur

        #: The TTS plugin that speaks the LLM output.
        if not tts or isinstance(tts, str):
            self.tts = AutoTTS.from_pretrained(tts=tts, **kwargs)
        else:
            self.tts = tts

        if self.tts:
            # Slow down TTS to realtime and be able to pause it.
            # NOTE(review): chunk=9600 is presumably a sample count per chunk - confirm against RateLimit.
            self.tts_output = RateLimit(rate=1.0, chunk=9600)
            self.tts.add(self.tts_output)
            self.llm.add(self.tts, ChatQuery.OutputWords)

            # Compare against None (not truthiness) so that e.g. device index 0 is accepted.
            if kwargs.get('audio_output_device') is not None:
                self.audio_output_device = AudioOutputDevice(**kwargs)
                self.tts_output.add(self.audio_output_device)

            if kwargs.get('audio_output_file') is not None:
                self.audio_output_file = AudioRecorder(**kwargs)
                self.tts_output.add(self.audio_output_file)

        #: Text prompts from web UI or CLI.
        self.prompt = UserPrompt(interactive=True, **kwargs)
        self.prompt.add(self.llm)

        # setup pipeline with two entry nodes (text prompt, and VAD -> ASR when enabled)
        self.pipeline = [self.prompt]
        if self.vad:
            self.pipeline.append(self.vad)

    def asr_partial(self, text):
        """
        Callback that occurs when the ASR has a partial transcript (while the user is speaking).
        These partial transcripts get revised mid-stream until the user finishes their phrase.
        This is also used for pausing/interrupting the bot output for when the user starts speaking.

        Args:
          text (str): the current partial transcript.
        """
        self.asr_history = text

        # Require at least 2 words before pausing, so a single word or stray noise doesn't
        # pause the bot. str.split() without a separator ignores leading, trailing and
        # repeated whitespace, so a transcript like " hello" counts as one word, not two.
        if len(text.split()) < 2:
            return

        if self.tts:
            self.tts_output.pause(1.0)  # NOTE(review): argument presumably a duration in seconds

    def asr_final(self, text):
        """
        Callback that occurs when the ASR outputs a final transcript after a pause in the
        user's speech, like at the end of a sentence or paragraph. This interrupts/cancels
        any ongoing bot output. It is registered before the LLM on the same ASR output,
        so the interruption happens before the new query is processed.

        Args:
          text (str): the final transcript (not used here; the LLM receives it separately).
        """
        self.asr_history = None
        self.on_interrupt()

    def on_interrupt(self):
        """
        Interrupt/cancel the bot output when the user submits (or speaks) a full query.
        Interrupts the LLM and, when TTS is enabled, the TTS and its rate limiter.
        """
        self.llm.interrupt(recursive=False)
        if self.tts:
            self.tts.interrupt(recursive=False)
            self.tts_output.interrupt(block=False, recursive=False)  # might be paused/asleep
 
 
if __name__ == "__main__":
    # CLI entry point: parse the default NanoLLM arguments plus the ASR, TTS and
    # audio-output options, build the voice agent, then run it.
    parser = ArgParser(extras=ArgParser.Defaults + ['asr', 'tts', 'audio_output'])
    args = parser.parse_args()

    # Keep the agent reference separate from the return value of run().
    agent = VoiceChat(**vars(args))
    agent.run()