Source code for nano_llm.chat.message

#!/usr/bin/env python3
import logging
import datetime

import torch
import numpy as np

from ..utils import ImageExtensions, ImageTypes, print_table


class ChatMessage():
    """
    Create a chat entry consisting of a text message, image, etc. as input.

    Args:
        role (str): The chat's turn template to apply, typically 'user' or 'bot'.
                    The role should have a corresponding entry in the active ChatTemplate.
        text (str): String containing the message's content for text messages.
        image (str|image): Either a np.ndarray, torch.Tensor, cudaImage, PIL.Image,
                           or a path to an image file (.jpg, .png, .bmp, etc.)
        kwargs: For messages with alternate content types, pass them in via kwargs
                and they will automatically be determined like so::

                    message = ChatMessage(role='user', audio='sounds.wav')

                There are additional lower-level kwargs that can be set below.

        use_cache (bool): cache the tokens/embeddings for reused prompts (defaults to false)
        tokens (list[int] or np.ndarray): the message contents already having been tokenized
        embedding (np.ndarray): the message contents already having been embedded
        history (ChatHistory): the ChatHistory object this message belongs to
    """
    def __init__(self, role='user', text=None, image=None, **kwargs):
        #: The content or media contained in the message
        self.content = None

        #: The type of the message ('text', 'image', 'audio', etc.)
        self.type = None

        #: The user role or character ('user', 'assistant', 'system', etc.)
        self.role = role

        #: The version of this message with the role template applied
        self.template = None

        #: The tokenized version of the message
        self.tokens = kwargs.get('tokens', None)

        #: The embedding of the message
        self.embedding = kwargs.get('embedding', None)

        #: The ChatHistory object this message belongs to
        self.history = kwargs.get('history', None)

        #: Set to true if the tokens/embeddings should be cached for reused prompts
        self.use_cache = kwargs.get('use_cache', False)

        #: Set to true if the message is already in the chat embedding
        self.cached = kwargs.get('cached', self.tokens is not None or self.embedding is not None)

        #: The index of this message in the chat history
        self.index = None

        #: The previous message in the chat history
        self.prev = None

        #: The next message in the chat history
        self.next = None

        # determine the message type
        if text is not None:
            self.content = text
            self.type = 'text'
        elif image is not None:
            self.content = image
            self.type = 'image'
        else:
            for key, value in kwargs.items():
                content_type = self.content_type(value)
                if content_type:
                    self.type = content_type
                    self.content = value
                    break

        if self.type is None:
            raise ValueError(f"couldn't find valid message content in {kwargs}, please specify its type")

        # apply variable substitutions
        #self.apply_substitutions(kwargs.get('substitutions'))

    @property
    def num_tokens(self):
        """
        Return the number of tokens used by this message.
        embed() needs to have been called for this to be valid.
        """
        if self.tokens is not None:
            if isinstance(self.tokens, (np.ndarray, torch.Tensor)):
                return self.tokens.shape[1]
            elif isinstance(self.tokens, list):
                return len(self.tokens)
            else:
                raise TypeError(f"ChatMessage had tokens with invalid type ({type(self.tokens)})")
        elif self.embedding is not None:
            return self.embedding.shape[1]
        else:
            return 0

    @property
    def start_token(self):
        """
        The token offset or position in the chat history at which this message begins.
        """
        offset = 0
        for i in range(0, self.index):
            offset += self.history[i].num_tokens
        return offset
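    # Usage sketch (illustrative, not part of the original file):
    #
    #   msg = ChatMessage(role='user', text='Hello!')        # explicit text message
    #   msg = ChatMessage(role='user', image='photo.jpg')    # explicit image message
    #   msg = ChatMessage(role='user', audio='sounds.wav')   # type inferred from kwargs
    #
    # num_tokens is 0 until embed() has tokenized the message, and start_token is
    # only valid once the message has been added to a ChatHistory (which sets index).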
    @staticmethod
    def content_type(content):
        """
        Try to automatically determine the message content type.
        """
        if isinstance(content, str):
            if content.endswith(ImageExtensions):
                return 'image'
            else:
                return 'text'
        elif isinstance(content, ImageTypes):
            return 'image'
        else:
            return None
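    # For example (illustrative; assumes ImageExtensions includes '.jpg'):
    #
    #   ChatMessage.content_type('photo.jpg')    ->  'image'   (by file extension)
    #   ChatMessage.content_type('hello world')  ->  'text'    (plain string)
    #   ChatMessage.content_type(42)             ->  None      (unrecognized type)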
    def is_type(self, type):
        """
        Return true if the message is of the given type (like 'text', 'image', etc.)
        """
        return (self.type == type)
    '''
    def apply_substitutions(self, substitutions=None):
        """
        Apply variable substitutions to the message content, like "Today's date is ${DATE}".
        This is separate from the templating that occurs with the special tokens & separators.
        """
        if self.type != 'text' or self.cached or substitutions is False:
            return

        if isinstance(substitutions, dict):
            for key, value in substitutions.items():
                self.content = self.content.replace(key, value)
            return

        if "${DATE}" in self.content:
            self.content = self.content.replace("${DATE}", datetime.date.today().strftime("%Y-%m-%d"))

        if "${TIME}" in self.content:
            self.content = self.content.replace("${TIME}", datetime.datetime.now().strftime("%-I:%M %p"))

        if "${TOOLS}" in self.content:
            from nano_llm import BotFunctions
            self.content = self.content.replace("${TOOLS}", BotFunctions.generate_docs(spec=self.history.template.tool_spec))

        if "${LOCATION}" in self.content:
            from nano_llm.plugins.bot_functions.location import LOCATION
            self.content = self.content.replace("${LOCATION}", LOCATION())
    '''
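    # If re-enabled, apply_substitutions() would expand variables in the text
    # before templating, e.g. (illustrative):
    #
    #   msg = ChatMessage(role='system', text="Today's date is ${DATE}.")
    #   msg.apply_substitutions()   # -> "Today's date is 2024-04-01." (the current date)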
    def embed(self, return_tensors='np', **kwargs):
        """
        Apply message templates, tokenization, and generate the embedding.
        """
        if self.embedding is not None:
            return self.embedding

        if self.history is None:
            raise RuntimeError("this message needs to be added to a ChatHistory before embed() is called")

        if self.tokens is not None and not self.history.model.has_embed:
            if isinstance(self.tokens, list):
                self.tokens = np.expand_dims(np.asarray(self.tokens, dtype=np.int32), axis=0)
            return self.tokens

        # lookup the role template to apply
        first_msg = 1 if 'system' in self.history.template else 0
        role = 'first' if 'first' in self.history.template and self.index == first_msg else self.role

        if role not in self.history.template:
            raise RuntimeError(f"chat template {self.history.template.get('name', '')} didn't have a role defined for '{self.role}' (had keys: {self.history.template.keys()})")

        # extract template prefix/postfix
        template = self.history.template[role]
        split_template = template.split('${MESSAGE}')

        if len(split_template) == 1:  # there was no ${MESSAGE}
            split_template.append('')

        if self.prev and self.prev.role == self.role:
            split_template[0] = ''

        if self.next and self.next.role == self.role:
            split_template[1] = ''

        # embed based on media type
        if self.type == 'text':
            self._embed_text(self.history.model, split_template, return_tensors=return_tensors, **kwargs)
        elif self.type == 'image':
            self._embed_image(self.history.model, split_template, return_tensors=return_tensors, **kwargs)

        # mark as cached
        self.cached = True

        if self.embedding is not None:
            return self.embedding

        if self.tokens is not None:
            return self.tokens
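    # Worked example of the template split above (illustrative; assumes a
    # ChatML-style role template, which may differ from the actual templates):
    #
    #   template = '<|im_start|>user\n${MESSAGE}<|im_end|>\n'
    #   template.split('${MESSAGE}')  ->  ['<|im_start|>user\n', '<|im_end|>\n']
    #
    # split_template[0] and split_template[1] are the prefix/postfix wrapped around
    # the message content, and consecutive messages from the same role drop the
    # redundant prefix/postfix between them.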
    def _embed_text(self, model, template, return_tensors='np', **kwargs):
        """
        Generate the token embeddings for a text message.
        """
        self.template = template[0] + self.content + template[1]

        if self.tokens is not None:
            if model.has_embed:
                self.embedding = model.embed_tokens(self.tokens, return_tensors=return_tensors, **kwargs)
        else:
            self.embedding, self.tokens = model.embed_text(
                self.template, use_cache=self.use_cache,
                return_tensors=return_tensors, return_tokens=True,
                **kwargs
            )

    def _embed_image(self, model, template, return_tensors='np', **kwargs):
        """
        Generate the encoded vision embeddings for an image.
        """
        if not model.has_vision:
            raise RuntimeError(f"attempted to embed an image in the chat, but '{model.config.name}' was not a multimodal vision model")

        embeddings = []

        # add the template prefix
        if template[0]:
            embeddings.append(model.embed_text(template[0], use_cache=True, return_tensors=return_tensors))

        # encode the image
        image_outputs = model.embed_image(self.content, return_tensors=return_tensors, return_dict=True)
        self.history.image_embedding = image_outputs.image_embeds  # save the unprojected embeddings for RAG
        embeddings.append(image_outputs.embedding)

        # add the template trailer
        template[1] = '\n' + template[1]

        if template[1]:
            embeddings.append(model.embed_text(template[1], use_cache=True, return_tensors=return_tensors))

        # concatenate all embeddings
        self.embedding = np.concatenate(embeddings, axis=1)

        if self.history.print_stats:
            print_table(model.vision.stats)

        #logging.debug(f"chat embed image shape={self.embedding.shape} dtype={self.embedding.dtype} template={template}")
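
if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file); run with
    # `python3 -m nano_llm.chat.message` so the relative imports resolve.
    # Constructing messages does not require a model; embed() does, since it needs
    # the ChatHistory's template and embedding backend.
    msg = ChatMessage(role='user', text='What is in this image?')
    print(msg.type, msg.role, msg.num_tokens)      # text user 0 (not yet tokenized)
    print(ChatMessage.content_type('photo.jpg'))   # image
    print(ChatMessage.content_type('plain text'))  # text
    print(msg.is_type('text'))                     # True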