#!/usr/bin/env python3
import logging
import datetime
import torch
import numpy as np
from ..utils import ImageExtensions, ImageTypes, print_table
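# NOTE: a stray "[docs]" marker stood here; it appears to be a documentation-page link artifact, not code.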
class ChatMessage():
    """
    Create a chat entry consisting of a text message, image, etc. as input.

    Args:

      role (str): The chat's turn template to apply, typically 'user' or 'bot'.
                  The role should have a corresponding entry in the active ChatTemplate.

      text (str): String containing the message's content for text messages.

      image (str|image): Either a np.ndarray, torch.Tensor, cudaImage, PIL.Image,
                         or a path to an image file (.jpg, .png, .bmp, etc.)

      kwargs: For messages with alternate content types, pass them in via kwargs
              and they will automatically be determined like so::

                 message = ChatMessage(role='user', audio='sounds.wav')

              There are additional lower-level kwargs that can be set below.

      use_cache (bool): cache the tokens/embeddings for reused prompts (defaults to false)
      tokens (list[int] or np.ndarray): the message contents already having been tokenized
      embedding (np.ndarray): the message contents already having been embedded
      history (ChatHistory): the ChatHistory object this message belongs to
      cached (bool): override for the ``cached`` flag; when omitted it is true
                     if ``tokens`` or ``embedding`` was supplied

    Raises:
      ValueError: if neither ``text`` nor ``image`` is given and none of the
                  kwargs values is recognized by :meth:`content_type`.
    """
    def __init__(self, role='user', text=None, image=None, **kwargs):

        #: The content or media contained in the message
        self.content = None

        #: The type of the message ('text', 'image', 'audio', etc.)
        self.type = None

        #: The user role or character ('user', 'assistant', 'system', etc.)
        self.role = role

        #: The version of this message with the role template applied
        self.template = None

        #: The tokenized version of the message
        self.tokens = kwargs.get('tokens', None)

        #: The embedding of the message
        self.embedding = kwargs.get('embedding', None)

        #: The ChatHistory object this message belongs to
        self.history = kwargs.get('history', None)

        #: Set to true if the tokens/embeddings should be cached for reused prompts
        self.use_cache = kwargs.get('use_cache', False)

        # The default is computed with explicit `is not None` tests: the truth
        # value of a multi-element array is ambiguous and raises ValueError,
        # and the default is evaluated even when 'cached' is passed explicitly.
        has_precomputed = self.tokens is not None or self.embedding is not None

        #: Set to true if the message is already in the chat embedding
        self.cached = kwargs.get('cached', has_precomputed)

        #: The index of this message in the chat history
        self.index = None

        #: The previous message in the chat history
        self.prev = None

        #: The next message in the chat history
        self.next = None

        # Determine the message type
        if text is not None:
            self.content = text
            self.type = 'text'
        elif image is not None:
            self.content = image
            self.type = 'image'
        else:
            # take the first kwarg value whose type can be recognized
            # (every kwarg is scanned, including the lower-level ones above)
            for value in kwargs.values():
                content_type = self.content_type(value)
                if content_type:
                    self.type = content_type
                    self.content = value
                    break

            if self.type is None:
                raise ValueError(f"couldn't find valid message content in {kwargs}, please specify its type")

        # Apply variable substitutions
        #self.apply_substitutions(kwargs.get('substitutions'))

    @property
    def num_tokens(self):
        """
        Return the number of tokens used by this message.
        embed() needs to have been called for this to be valid.

        Returns 0 when the message has neither tokens nor an embedding.

        Raises:
          TypeError: if ``tokens`` is not a list, np.ndarray, or torch.Tensor.
        """
        if self.tokens is not None:
            if isinstance(self.tokens, (np.ndarray, torch.Tensor)):
                # the last axis holds the sequence for both 1-D (N,) arrays
                # and the (1, N) arrays that embed() builds from token lists
                return self.tokens.shape[-1]
            elif isinstance(self.tokens, list):
                return len(self.tokens)
            else:
                raise TypeError(f"ChatMessage had tokens with invalid type ({type(self.tokens)})")
        elif self.embedding is not None:
            # the token count is read from axis 1 of the embedding
            return self.embedding.shape[1]
        else:
            return 0

    @property
    def start_token(self):
        """
        The token offset or position in the chat history at which this message begins.

        Computed by summing ``num_tokens`` of every history entry before
        ``self.index``, so both ``history`` and ``index`` must be set.
        """
        return sum(self.history[i].num_tokens for i in range(self.index))

    @staticmethod
    def content_type(content):
        """
        Try to automatically determine the message content type.

        Returns 'image' for strings ending with one of ``ImageExtensions`` and
        for instances of ``ImageTypes``, 'text' for any other string, and
        None when the type is not recognized.
        """
        if isinstance(content, str):
            if content.endswith(ImageExtensions):
                return 'image'
            else:
                return "text"
        elif isinstance(content, ImageTypes):
            return 'image'
        else:
            return None

    def is_type(self, type):
        """
        Return true if the message is of the given type (like 'text', 'image', etc.)
        """
        return (self.type == type)

    # Disabled: the method below is kept inside a string literal and its call
    # in __init__ is commented out, so it is never defined or executed.
    '''
    def apply_substitutions(self, substitutions=None):
        """
        Apply variable substitutions to the message content, like "Today's date is ${DATE}".
        This is separate from the templating that occurs with the special tokens & separators.
        """
        if self.type != 'text' or self.cached or substitutions is False:
            return
            
        if isinstance(substitutions, dict):
            for key, value in substitutions.items():
                self.content = self.content.replace(key, value)
            return
            
        if "${DATE}" in self.content:
            self.content = self.content.replace("${DATE}", datetime.date.today().strftime("%Y-%m-%d"))
            
        if "${TIME}" in self.content:
            self.content = self.content.replace("${TIME}", datetime.datetime.now().strftime("%-I:%M %p"))
           
        if "${TOOLS}" in self.content:
            from nano_llm import BotFunctions
            self.content = self.content.replace("${TOOLS}", BotFunctions.generate_docs(spec=self.history.template.tool_spec))
          
        if "${LOCATION}" in self.content:
            from nano_llm.plugins.bot_functions.location import LOCATION
            self.content = self.content.replace("${LOCATION}", LOCATION())
    '''

    def embed(self, return_tensors='np', **kwargs):
        """
        Apply message templates, tokenization, and generate the embedding.

        Args:
          return_tensors (str): forwarded to the model's embedding functions.
          kwargs: forwarded to the model when embedding text messages.

        Returns:
          The message embedding if one exists or was produced, otherwise the
          message tokens, otherwise None.

        Raises:
          RuntimeError: if the message has no embedding yet and is not attached
                        to a ChatHistory, or if the chat template has no entry
                        for this message's role.
        """
        if self.embedding is not None:
            return self.embedding

        # everything below needs the history's model/template, so check for it
        # before touching self.history (otherwise this surfaces as an
        # AttributeError on None instead of a helpful message)
        if self.history is None:
            raise RuntimeError("this message needs to be added to a ChatHistory before embed() is called")

        if self.tokens is not None and not self.history.model.has_embed:
            if isinstance(self.tokens, list):
                self.tokens = np.expand_dims(np.asarray(self.tokens, dtype=np.int32), axis=0)
            return self.tokens

        # lookup the role template to apply
        first_msg = 1 if 'system' in self.history.template else 0
        role = 'first' if 'first' in self.history.template and self.index == first_msg else self.role

        if role not in self.history.template:
            raise RuntimeError(f"chat template {self.history.template.get('name', '')} didn't have a role defined for '{self.role}' (had keys: {self.history.template.keys()})")

        # extract template prefix/postfix
        template = self.history.template[role]
        split_template = template.split('${MESSAGE}')

        if len(split_template) == 1:  # there was no ${MESSAGE}
            split_template.append('')

        # consecutive messages from the same role share one prefix/postfix
        if self.prev and self.prev.role == self.role:
            split_template[0] = ''

        if self.next and self.next.role == self.role:
            split_template[1] = ''

        # embed based on media type
        if self.type == 'text':
            self._embed_text(self.history.model, split_template, return_tensors=return_tensors, **kwargs)
        elif self.type == 'image':
            self._embed_image(self.history.model, split_template, return_tensors=return_tensors, **kwargs)

        # mark as cached
        self.cached = True

        if self.embedding is not None:
            return self.embedding

        if self.tokens is not None:
            return self.tokens

    def _embed_text(self, model, template, return_tensors='np', **kwargs):
        """
        Generate the token embeddings for a text message.

        ``template`` is the [prefix, postfix] pair placed around the content;
        the joined string is stored in ``self.template``.  Existing tokens are
        embedded directly when the model has an embedding function, otherwise
        the templated text is tokenized and embedded by the model.
        """
        self.template = template[0] + self.content + template[1]

        if self.tokens is not None:
            if model.has_embed:
                self.embedding = model.embed_tokens(self.tokens, return_tensors=return_tensors, **kwargs)
        else:
            self.embedding, self.tokens = model.embed_text(
                self.template, use_cache=self.use_cache,
                return_tensors=return_tensors, return_tokens=True,
                **kwargs
            )

    def _embed_image(self, model, template, return_tensors='np', **kwargs):
        """
        Generate the encoded vision embeddings for an image.

        The result stored in ``self.embedding`` is the concatenation (along
        axis 1) of the embedded template prefix (if any), the image embedding,
        and the embedded template postfix preceded by a newline.

        Raises:
          RuntimeError: if the model does not have vision support.
        """
        if not model.has_vision:
            raise RuntimeError(f"attempted to embed an image in the chat, but '{model.config.name}' was not a multimodal vision model")

        # add the template prefix
        embeddings = []

        if template[0]:
            embeddings.append(model.embed_text(template[0], use_cache=True, return_tensors=return_tensors))

        # encode the image
        image_outputs = model.embed_image(self.content, return_tensors=return_tensors, return_dict=True)
        self.history.image_embedding = image_outputs.image_embeds # save the unprojected embeddings for RAG
        embeddings.append(image_outputs.embedding)

        # add the template trailer: a newline is always prepended, so the
        # trailer is never empty and is embedded unconditionally (built as a
        # local so the caller's template list is not modified)
        trailer = '\n' + template[1]
        embeddings.append(model.embed_text(trailer, use_cache=True, return_tensors=return_tensors))

        # concatenate all embeddings
        # NOTE(review): np.concatenate is used regardless of return_tensors --
        # confirm the behavior expected for return_tensors='pt'
        self.embedding = np.concatenate(embeddings, axis=1)

        if self.history.print_stats:
            print_table(model.vision.stats)

        #logging.debug(f"chat embed image  shape={self.embedding.shape}  dtype={self.embedding.dtype}  template={template}")