#!/usr/bin/env python3
import time
import queue
import threading
import logging
import traceback
from nano_llm.web import WebServer
from nano_llm.utils import AttributeDict, inspect_function, json_type, python_type
[docs]
class Plugin(threading.Thread):
    """
    Base class for plugins that process incoming/outgoing data from connections
    with other plugins, forming a pipeline or graph.  Plugins can run either
    single-threaded or in an independent thread that processes data out of a queue.
    
    Frequent categories of plugins:
    
      * sources:  text prompts, images/video
      * process:  LLM queries, RAG, dynamic LLM calls, image post-processors
      * outputs:  print to stdout, save images/video
      
    Inherited plugins should implement the :func:`process` function to handle incoming data.
    """
    # Every instance appends itself here in __init__ and is removed again by destroy().
    Instances = []  #: Global list of plugin instances
        
    def __init__(self, name=None, title=None, inputs=1, outputs=1,
                 relay=False, drop_inputs=False, threaded=True, **kwargs):
        """
        Base initializer for Plugin implementations.
        
        Args:
            name (str):  name of this plugin instance (defaults to the class name)
            title (str):  optional display title, stored as-is on the instance
            inputs (int|str|list[str]):  input channel name(s); anything that is not
                                         a str or list results in the single name ``'0'``
            outputs (int|str|list[str]|None):  number of output channels, or their name(s).
                                               ``None`` creates a plugin without output channels.
            relay (bool): if true, will relay any inputs as outputs after processing
            drop_inputs (bool): if true, only the most recent input in the queue will be used
            threaded (bool): if true, will spawn independent thread for processing the queue.
            kwargs:  ``input_channels`` / ``output_channels`` are accepted as aliases
                     that take precedence over ``inputs`` / ``outputs``.
                     
        Raises:
            TypeError: if ``outputs`` is not an int, str, list or None.
        """
        super().__init__(daemon=True)

        # identity and behaviour switches
        self.name = name or self.__class__.__name__
        self.title = title
        self.relay = relay
        self.drop_inputs = drop_inputs
        self.threaded = threaded

        # runtime status flags
        self.stop_flag = False
        self.interrupted = False
        self.processing = False

        # registries that add_parameter() and subclasses fill in
        self.parameters = {}
        self.tools = {}

        # the *_channels keyword aliases override the positional settings
        inputs = kwargs.get('input_channels', inputs)
        outputs = kwargs.get('output_channels', outputs)

        if isinstance(inputs, str):
            inputs = [inputs]

        self.input_names = inputs if isinstance(inputs, list) else ['0']

        if isinstance(outputs, str):
            outputs = [outputs]

        if isinstance(outputs, list):
            self.output_names = outputs
        elif isinstance(outputs, int):
            self.output_names = [str(index) for index in range(outputs)]
        elif outputs is None:
            self.output_names = []
        else:
            raise TypeError(f"outputs should have been int, str, list[str], or None  (was {type(outputs)})")

        # one list of connected plugins per output channel
        self.outputs = [[] for _ in self.output_names]

        self.add_parameter('layout_grid', type=dict, default={}, hidden=True)
        self.add_parameter('layout_node', type=dict, default={}, hidden=True)

        if threaded:
            self.input_queue = queue.Queue()
            self.input_event = threading.Event()

        # imported here rather than at module level
        from nano_llm import BotFunctions
        self.BotFunctions = BotFunctions

        Plugin.Instances.append(self)
     
    def __del__(self):
        """
        Stop the plugin from running and unregister it.
        """
        self.destroy()
[docs]
    def process(self, input, **kwargs):
        """
        Abstract function that plugin instances should implement to process incoming data.
        Don't call this function externally unless ``threaded=False``, because
        otherwise the plugin's internal thread dispatches from the queue.
        Args:
        
          input: input data to process from the previous plugin in the pipeline
          kwargs: optional processing arguments that accompany this data
          
        Returns:
        
          Plugins should return their output data to be sent to downstream plugins.
          You can also call :func:`output()` as opposed to returning it.
        """
        logging.warning(f"plugin {self.name} did not implement process() - dropping input") 
    
[docs]
    def connect(self, plugin, channel=0, **kwargs):
        """
        Register another plugin (or a plain callable) as a receiver of this
        plugin's output on the given channel.
        
        Args:
        
          plugin (Plugin|callable): either the plugin to link to, or a callback function.
                                    Callables are wrapped in a ``Callback`` plugin that is
                                    constructed with ``kwargs``.
          channel (int): the output channel of this plugin to link the other plugin to.
                        
        Returns:
          A reference to this plugin instance (self)
          
        Raises:
          TypeError: if ``plugin`` is neither a Plugin nor callable.
        """
        from nano_llm.plugins import Callback

        is_plugin = isinstance(plugin, Plugin)

        if not is_plugin and not callable(plugin):
            raise TypeError(f"{type(self)}.connect() expects either a Plugin instance or a callable function (was {type(plugin)})")

        if not is_plugin:
            plugin = Callback(plugin, **kwargs)

        self.outputs[channel].append(plugin)

        # callbacks are reported by the name of the wrapped function  (TODO https://stackoverflow.com/a/25959545)
        target = plugin.function.__name__ if isinstance(plugin, Callback) else plugin.name
        logging.debug(f"connected {self.name} to {target} on channel={channel}")

        return self
        
[docs]
    def add(self, plugin, channel=0, **kwargs):
        """
        @deprecated Please use :func:``Plugin.connect``
        """
        return self.connect(plugin, channel=channel, **kwargs) 
    
[docs]
    def __call__(self, input=None, **kwargs):
        """
        Callable () operator alias for the :func:`input()` function.
        This is provided for a more intuitive way of processing data 
        like ``plugin(data)`` instead of ``plugin.input(data)``
        
        Args:
        
          input: input data sent to the plugin's :func:`process()` function.
          kwargs: additional arguments forwarded to the plugin's :func:`process()` function.
          
        Returns:
        
          None if the plugin is threaded, otherwise returns any outputs.
        """
        return self.input(input, **kwargs) 
        
            
[docs]
    def output(self, output, channel=0, **kwargs):
        """
        Output data to the next plugin(s) on the specified channel (-1 for all channels)
        """
        #if output is None:
        #    return
        if channel >= 0:
            kwargs.update(dict(sender=self, channel=channel))
            for output_plugin in self.outputs[channel]:
                output_plugin.input(output, **kwargs)
        else:
            for output_channel in self.outputs:
                kwargs.update(dict(sender=self, channel=output_channel))
                for output_plugin in output_channel:
                    output_plugin.input(output, **kwargs)
                    
        return output 
     
    @property
    def num_outputs(self):
        """
        Return the total number of output connections across all channels
        """
        count = 0
        for output_channel in self.outputs:
            count += len(output_channel) 
        return count
        
[docs]
    def start(self):
        """
        Start threads for all plugins in the graph that have threading enabled.
        """
        if self.threaded:
            if not self.is_alive():
                super().start()
            
        for output_channel in self.outputs:
            for output in output_channel:
                output.start()
                
        return self 
     
[docs]
    def stop(self):
        """
        Flag the plugin to stop processing and exit the run() thread.
        """
        self.stop_flag = True
        logging.debug(f"stopping plugin {self.name} (thread {self.native_id})") 
     
[docs]
    def destroy(self):
        """
        Stop a plugin thread's running, and unregister it from the global instances.
        """ 
        self.stop()
        
        try:
            Plugin.Instances.remove(self)
        except ValueError:
            logging.warning(f"Plugin {getattr(self, 'name', '')} wasn't in global instances list when being deleted") 
            
[docs]
    def run(self):
        """
        Processes the queue forever and automatically run when created with ``threaded=True``
        """
        while not self.stop_flag:
            try:
                if not self.input_event.wait(timeout=0.25):
                    continue
                    
                self.input_event.clear()
                
                while not self.stop_flag:
                    try:
                        input, kwargs = self.input_queue.get(block=False)
                        self.dispatch(input, **kwargs)
                    except queue.Empty:
                        break
            except Exception as error:
                logging.error(f"Exception occurred during processing of {self.name}\n\n{traceback.format_exc()}")
        logging.debug(f"{self.name} plugin stopped (thread {self.native_id})") 
[docs]
    def dispatch(self, input, **kwargs):
        """
        Invoke the process() function on incoming data
        """
        if self.interrupted:
            #logging.debug(f"{type(self)} resetting interrupted flag to false")
            self.interrupted = False
          
        self.processing = True
        outputs = self.process(input, **kwargs)
        self.processing = False
        if outputs is not None:
            self.output(outputs)
        
        if self.relay:
            self.output(input)
            
        return outputs 
   
[docs]
    def interrupt(self, clear_inputs=True, recursive=True, block=None):
        """
        Interrupt any ongoing/pending processing, and optionally clear the input queue
        along with any downstream queues, and optionally wait for processing of those 
        requests to have finished.
        
        Args:
        
          clear_inputs (bool):  if True, clear any remaining inputs in this plugin's queue.
          recursive (bool):  if True, then any downstream plugins will also be interrupted.
          block (bool):  is true, this function will wait until any ongoing processing has finished.
                         This is done so that any lingering outputs don't cascade downstream in the pipeline.
                         If block is None, it will automatically be set to true if this plugin has outputs.
        """
        #logging.debug(f"interrupting plugin {type(self)}  clear_inputs={clear_inputs} recursive={recursive} block={block}")
        
        if clear_inputs:
            self.clear_inputs()
          
        self.interrupted = True
        
        num_outputs = self.num_outputs
        block_other = block
        
        if block is None and num_outputs > 0:
            block = True
            
        while block and self.processing:
            #logging.debug(f"interrupt waiting for {type(self)} to complete processing")
            time.sleep(0.01) # TODO use an event for this?
        
        if recursive and num_outputs > 0:
            for output_channel in self.outputs:
                for output in output_channel:
                    output.interrupt(clear_inputs=clear_inputs, recursive=recursive, block=block_other) 
                    
[docs]
    def find(self, type):
        """
        Return the plugin with the specified type by searching for it among
        the pipeline graph of inputs and output connections to other plugins.
        """
        if isinstance(self, type):
            return self
            
        for output_channel in self.outputs:
            for output in output_channel:
                if isinstance(output, type):
                    return output
                plugin = output.find(type)
                if plugin is not None:
                    return plugin
            
        return None 
        
[docs]
    def add_parameter(self, attribute: str, name=None, type=None, default=None,
                      read_only=False, hidden=False, help=None, kwarg=None, end=None, **kwargs):
        """
        Make an attribute that is shared in the state_dict and can be accessed/modified by clients.
        These will automatically show up in the studio web UI and can be sync'd over websockets.
        If there is an __init__ parameter by the same name, its help docs will be taken from that.
        
        Args:
        
          attribute (str):  name of the instance attribute backing the parameter
          name (str):  display name (defaults to the attribute name, title-cased with spaces)
          type:  parameter type, converted with ``json_type``; when omitted, the type
                 recorded for the matching ``__init__`` argument is used
          default:  initial value, assigned to the attribute unless ``read_only`` is set
          read_only (bool):  stored in the parameter entry; also skips assigning the default
          hidden (bool):  stored in the parameter entry
          help (str):  help text (falls back to the ``__init__`` argument's help)
          kwarg (str):  name of the ``__init__`` argument to look up (defaults to ``attribute``)
          end:  when truthy, stored as ``'end'`` in the parameter entry
          kwargs:  additional entries merged into the parameter entry
          
        Returns:
        
          The parameter dict, which is also stored in ``self.parameters[attribute]``.
        """
        init_args = inspect_function(self.__init__)['parameters']
        init = init_args.get(kwarg or attribute, {})

        if not read_only: #if not hasattr(self, attribute):
            setattr(self, attribute, default)

        if name is None:
            name = attribute.replace('_', ' ').title()

        type = init.get('type') if type is None else json_type(type)

        param = dict(display_name=name, type=type, read_only=read_only, hidden=hidden)

        # subclasses may provide per-attribute overrides through type_hints()
        if hasattr(self, 'type_hints'):
            for hinted, overrides in self.type_hints().items():
                if hinted == attribute:
                    param.update(overrides)

        help = help or init.get('help')

        # optional entries are only stored when they carry a truthy value
        for key, value in (('help', help), ('default', default), ('end', end)):
            if value:
                param[key] = value

        param.update(**kwargs)
        self.parameters[attribute] = param
        return param
    
[docs]
    def add_parameters(self, **kwargs):
        """
        Add parameters from kwargs of the form ``Plugin.add_parameters(x=x, y=y)`` 
        where the keys are the attribute names and the values are the default values.
        """
        for key, value in kwargs.items():
            self.add_parameter(key, default=value) 
            
[docs]
    def set_parameters(self, **kwargs):
        """
        Set a state dict of parameters. Only entries in the dict matching a parameter will be set.
        """
        for attr, value in kwargs.items():
            if attr not in self.parameters:
                if attr != 'name' and attr != 'type' and attr != 'connections':
                    logging.warning(f"attempted to set unknown parameter {self.name}.{attr}={value} (skipping)")
                continue
            logging.debug(f"{self.name} setting parameter '{attr}' to {value}")
            if self.parameters[attr]['type'] == 'boolean' and isinstance(value, str):
                value = value.lower()
                if value == 'true' or value == '1':
                    value = True
                else:
                    value = False
            setattr(self, attr, value) 
    
[docs]
    def reorder_parameters(self):
        """
        Move some parameters to the end for display purposes (if end=True)
        """
        if hasattr(self, '_reordered_parameters') and self._reordered_parameters:
            return
            
        params = self.parameters.copy()
        
        for param_name, param in params.items():
            if 'end' in param:
                del param['end']
                del self.parameters[param_name]
                self.parameters[param_name] = param
        
        self._reordered_parameters = True 
                    
[docs]
    def state_dict(self, config=False, connections=False, hidden=False, **kwargs):
        """
        Return a configuration dict with plugin state that gets shared with clients. 
        Subclasses can reimplement this to add custom state for each type of plugin.
        """
        state = {
            'name': self.name,
            'type': self.__class__.__name__,
        }
        
        if config:
            connections = True
            
        if connections:
            connections = []
            
            for c, output_channel in enumerate(self.outputs):
                for output in output_channel:
                    connections.append({
                        'to': output.name,
                        'input': 0,
                        'output': c
                     })
            
            state['connections'] = connections
   
        if config:
            self.reorder_parameters()
            
            state.update({
                'title': self.title if self.title else self.name,
                'inputs': self.input_names,
                'outputs': self.output_names,
                'parameters': self.parameters,
            })
        
        for attr, param in self.parameters.items():
            if hidden or not param['hidden'] or config:
                state[attr] = getattr(self, attr)
            
        return state 
               
[docs]
    def send_state(self, state_dict=None, **kwargs):
        """
        Send this plugin's state dict to clients over the websocket, keyed by the plugin name.
        Logs a warning and does nothing if there is no webserver or it is not connected.
        
        Args:
        
          state_dict (dict):  the state to send; when None it is built with
                              ``self.state_dict(**kwargs)``.
        """
        server = WebServer.Instance

        if not (server and server.connected):
            logging.warning(f"plugin {self.name} had no webserver or connected clients to send state_dict")
            return

        payload = self.state_dict(**kwargs) if state_dict is None else state_dict

        WebServer.Instance.send_message({
            'state_dict': {self.name: payload}
        })
     
[docs]
    def send_stats(self, stats={}, **kwargs):
        """
        Send performance stats over the websocket.
        """
        if not WebServer.Instance or not WebServer.Instance.connected or (not stats and not kwargs):
            return
       
        stats.update(kwargs)
        WebServer.Instance.send_message({
            'stats': {self.name: stats}
        }) 
           
[docs]
    def send_alert(self, message, **kwargs):
        """
        Forward an alert message to the webserver (see WebServer.send_alert() for kwargs).
        Returns None without sending when there is no webserver or it is not connected;
        otherwise returns the result of ``WebServer.send_alert()``.
        """
        server = WebServer.Instance

        if server and server.connected:
            return WebServer.Instance.send_alert(message, **kwargs)

        return None
     
[docs]
    def send_client_output(self, channel):
        """
        Subscribe clients to receiving plugin output over websockets, by connecting
        a started ``WebClient`` plugin to the given output channel.  Does nothing
        if a ``WebClient`` is already connected on that channel.
        """
        from nano_llm.plugins import WebClient

        already_subscribed = any(isinstance(receiver, WebClient) for receiver in self.outputs[channel])

        if already_subscribed:
            return

        client = WebClient()
        client.start()

        self.connect(client, channel=channel)
[docs]
    def apply_substitutions(self, text):
        """
        Perform variable substitution on a string of text by looking up values from other plugins.
        
        References can be scoped to a plugin's name:  "The date is ${Clock.date}"
        Or if left unscoped, find the first plugin with it:  "The date is ${date}"
        Both plugins and attributes are case-insensitive:  "The date is ${DATE}"
        
        These can also refer to getter functions or properties that require no positional arguments,
        and if found the associate function will be called and its return value substituted instead.
        """
        def find_closing_bracket(s : str):
            for i, c in enumerate(s):
                if c == '}':
                    return i
                elif c.isalnum() or c == '_' or c == '.':
                    continue
         
        def read_param(var : str):
            period = var.find('.')
            
            if period > 0 and period < len(var) - 1:
                plugin_name = var[:period].lower()
                plugin_attr = var[period+1:]
            else:
                plugin_name = None
                plugin_attr = var
            
            bot_functions = self.BotFunctions()
            for plugin in [self] + Plugin.Instances + self.BotFunctions():  # resolve unclassed references to this plugin first
                if plugin_name and plugin_name != plugin.name.lower():
                    continue
                    
                if not hasattr(plugin, plugin_attr):
                    plugin_attr_lower = plugin_attr.lower()
                    if hasattr(plugin, plugin_attr_lower):
                        plugin_attr = plugin_attr_lower
                    elif hasattr(plugin, 'function') and getattr(plugin, 'name', '').lower() == plugin_attr_lower:
                        return str(plugin.function())
                    else:
                        continue
                        
                value = getattr(plugin, plugin_attr)
                
                if callable(value):
                    value = value()
                    
                return str(value)
          
            logging.warning(f"{self.name} could not find variable ${{{var}}} for substitution")
            return f"${{{var}}}"
            
        #while True:
        splits = text.split('${')
        string = ''
    
        if len(splits) <= 1:
            return text
        
        for split in splits:
            if not split:
                continue
                
            end = find_closing_bracket(split)
            
            if end is None:
                string = string + split
                continue
                
            var = split[:end].strip()
            string = string + read_param(var) 
        
            if end < len(split)-1:
                string = string + split[end+1:]
                
        return string
        '''                
            if text != string:
                text = string
                continue
            else:        
                return string  
        '''