#!/usr/bin/env python3
import os
import re
import json
import logging
import termcolor
import numpy as np
from .message import ChatMessage
from .stream import StreamingResponse
from .templates import ChatTemplate, ChatTemplates, StopTokens
from ..utils import AttributeDict, escape_html, code_tags
class ChatHistory():
"""
Multimodal chat history that can contain a mix of media including text/images.
ChatHistory objects can be indexed like a list to access its messages,
where each :class:`ChatMessage` can have a different type of content::
chat_history[n] # will return the n-th chat entry
Each type of media has an associated embedding function (e.g. LLM's typically
do text token embedding internally, and images use CLIP + projection layers).
From these, it assembles the embedding for the entire chat as input to the LLM.
It uses templating to add the required special tokens as defined by different
model architectures. In normal 2-turn chat, there are 'user' and 'bot' roles
defined, but arbitrary roles can be added, each with their own template.
The system prompt can also be configured through the chat template
and by setting the :attr:`ChatHistory.system_prompt` property.
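A minimal usage sketch (assuming ``model`` is an already-loaded :class:`NanoLLM` instance)::

    chat_history = ChatHistory(model, system_prompt="You are a helpful assistant.")
    chat_history.append('user', 'Why is the sky blue?')  # add a user message
    embedding, position = chat_history.embed_chat()      # assemble the LLM input embedding
    print(len(chat_history), chat_history[0].role)       # list-style access to the messages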
"""
def __init__(self, model, chat_template=None, system_prompt=None, **kwargs):
"""
Parameters:
model (NanoLLM): The model instance used for embeddings
chat_template (str|dict): Either a chat template dict, or the name of the
chat template to use like ``llama-2`` or ``vicuna-v1``.
If ``None``, will attempt to determine the model type automatically (see the example below).
system_prompt (str): Set the default system prompt used at the beginning of chats.
If ``None``, will use system prompt from the template by default.
tools (bool|str): If True, tool calling will be enabled for models that have the
``tool_call`` and ``tool_response`` roles in their chat templates.
When enabled, the function descriptors will automatically be generated
from their pydoc strings, and appended to the system prompt.
print_stats (bool): If True, generation performance will be printed to the terminal after EOS.
This also gets enabled by default if ``--debug`` or ``--verbose`` is used.
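For example, the chat template can be selected explicitly instead of auto-detected
(a sketch; the template name and prompt shown here are only illustrative)::

    chat_history = ChatHistory(model, chat_template='llama-2',
                               system_prompt="Answer concisely.")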
"""
self.model = model
self.messages = None
#: The :class:`KVCache` from :meth:`NanoLLM.generate()` used to store the model state.
self.kv_cache = None
# look-up or load the chat template
if not chat_template or chat_template == 'auto':
self.template = ChatTemplate(model)
if self.template is None:
raise RuntimeError(f"Couldn't automatically determine model type from {model.config.name}, please set the --chat-template argument")
logging.info(f"using chat template '{self.template.name}' for model {model.config.name}")
elif isinstance(chat_template, str):
if os.path.isfile(chat_template):
with open(chat_template) as template_file:
self.template = AttributeDict(json.load(template_file))
else:
self.template = AttributeDict(ChatTemplates[chat_template])
elif isinstance(chat_template, dict):
self.template = AttributeDict(chat_template)
else:
raise TypeError(f"chat_template should be a str or dict (was {type(chat_template)})")
# parse the stop tokens
if 'stop' in self.template:
if not isinstance(self.template.stop, list):
self.template.stop = [self.template.stop]
for i, stop in enumerate(self.template.stop):
if isinstance(stop, str):
self.template.stop[i] = self.model.tokenizer(stop, add_special_tokens=False, return_tensors='np').input_ids.squeeze().tolist()
else:
self.template.stop = [self.model.tokenizer.eos_token_id]
#self.template.stop = [x for x in self.template.stop if x >= 0] # filter out ignored stop tokens
logging.info(f"model '{self.model.config.name}', chat template '{self.template.name}' stop tokens: {self.model.tokenizer.batch_decode(self.template.stop)} -> {self.template.stop}")
# setup the default system prompt
if system_prompt:
self.template['system_prompt'] = system_prompt
# try to determine the function-calling style
if 'tool_spec' not in self.template:
if 'tool_call' in self.template:
self.template.tool_spec = 'openai'
else:
self.template.tool_spec = kwargs.get('tool_spec')
self.print_stats = kwargs.get('print_stats', kwargs.get('debug', False))
self.web_regex = [
(re.compile(r'`(.*?)`'), r'<code>\1</code>'), # inline code spans
(re.compile(r'\*(.*?)\*'), r'*<i>\1</i>*'), # emotives inside asterisks
]
from nano_llm import BotFunctions
self.BotFunctions = BotFunctions
self.reset()
@property
def num_tokens(self):
"""
Return the number of tokens used by the chat so far.
:meth:`embed_chat()` needs to have been called for this to be updated,
because otherwise the input wouldn't have been tokenized yet.
"""
position = 0
for msg in self.messages:
position += msg.num_tokens
return position
def __len__(self):
"""
Returns the number of messages in the chat history
"""
return len(self.messages)
def __getitem__(self, key):
"""
Return the n-th chat message with the subscript indexing operator
"""
return self.messages[key]
def __delitem__(self, key):
"""
Remove one or more messages from the chat history::
del chat_history[-2] # remove the second-to-last entry
del chat_history[-2:] # pop the last 2 entries
del chat_history[1:] # remove all entries but the first
This will also update the KV cache and alter the bot's memory.
"""
if isinstance(key, int):
start = key
stop = key + 1
elif isinstance(key, ChatMessage):
start = self.messages.index(key)
stop = start + 1
elif isinstance(key, slice):
start = key.start
stop = key.stop
else:
raise TypeError(f"The `del chat_history[*]` operator expects an int, ChatMessage, or slice (was '{type(key)}')")
if start is None:
start = 0
if stop is None:
stop = len(self.messages)
self.remove(start, stop)
def append(self, role='user', msg=None, **kwargs):
"""
Add a chat entry consisting of a text message, image, etc.
See the :class:`ChatMessage` class for description of arguments.
This can also accept an existing :class:`ChatMessage` set to ``msg``.
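For example (a sketch; see :class:`ChatMessage` for the accepted keyword arguments,
and the image path here is only illustrative)::

    chat_history.append('user', text='What kind of fruit is this?')
    chat_history.append('user', image='/data/images/fruit.jpg')
    chat_history.append('bot', 'It looks like a dragonfruit.')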
"""
if isinstance(msg, ChatMessage):
self.messages.append(msg)
elif isinstance(msg, StreamingResponse):
self.messages.append(ChatMessage(role, text=msg.text, tokens=msg.tokens, history=self, **kwargs))
self.kv_cache = msg.kv_cache
else:
self.messages.append(ChatMessage(role, msg=msg, history=self, **kwargs))
self.reindex()
return self.messages[-1]
def pop(self, count):
"""
Remove the last N messages from the chat and KV cache.
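For example, to drop the most recent question/answer exchange::

    chat_history.pop(2)  # removes the last 2 messages and rolls back the KV cache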
"""
num_tokens = 0
for n in range(0, count):
num_tokens += self.messages[len(self.messages)-n-1].num_tokens
if self.kv_cache:
self.kv_cache.pop(num_tokens)
del self.messages[-count:]
self.reindex()
def remove(self, start, stop=None):
"""
Remove the chat entries from the start (inclusive) to stop (exclusive) indexes.
If stop is not specified, then only the single entry at the start index will be removed::
chat_history.remove(0) # remove the first chat entry
chat_history.remove(0,2) # remove the first and second chat entries
chat_history.remove(-1) # remove the last chat entry
chat_history.remove(-2,0) # remove the last two entries
This will also update the KV cache and alter the bot's memory (potentially destructively)
"""
num_messages = len(self.messages)
if stop is None:
stop = start + 1
if start < 0:
start += num_messages
if stop <= 0:
stop += num_messages
if stop > num_messages:
raise ValueError(f"remove index {stop} exceeded the number of messages ({num_messages})")
if stop == num_messages:
return self.pop(num_messages - start)
if self.kv_cache:
self.kv_cache.remove(self.messages[start].start_token, self.messages[stop].start_token)
del self.messages[start:stop]
self.reindex()
def reset(self, system_prompt=True, use_cache=True, wrap_tokens=None):
"""
Reset the chat history, and optionally add the system prompt to the new chat.
If ``use_cache=True``, then the system prompt tokens/embedding will be cached.
If ``wrap_tokens`` is set, then the most recent N tokens from the chat will be kept.
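For example (the prompt string is illustrative)::

    chat_history.reset()                                   # new chat with the default system prompt
    chat_history.reset(system_prompt="You are a pirate.")  # new chat with a different system prompt
    chat_history.reset(system_prompt=False)                # new chat without any system prompt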
"""
if wrap_tokens:
wrap_entry = self.find_wrap_entry(wrap_tokens)
if wrap_entry:
logging.warning(f"Wrapping chat to keep the most recent {len(self.messages)-wrap_entry} messages")
self.messages = self.messages[wrap_entry:]
else:
logging.warning(f"Chat history overflow couldn't find previous chat entry to wrap to (clearing chat)")
self.messages = []
else:
self.messages = []
self.kv_cache = None
self.image_embedding = None
if isinstance(system_prompt, str):
self.add_system_prompt(system_prompt=system_prompt, use_cache=use_cache)
elif system_prompt:
self.add_system_prompt(use_cache=use_cache)
def turn(self, role='user'):
"""
Returns true if it's the given role's turn in the chat, otherwise false.
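For example, this can gate whose message gets submitted next (a sketch)::

    if chat_history.turn('user'):
        chat_history.append('user', input('>> '))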
"""
n = len(self.messages)
prev_role = self.messages[n-1].role if n > 0 else None
if role == 'system':
return (n == 0)
elif role == 'user':
if n == 0:
return ('system' not in self.template)
else:
return (prev_role != 'tool_response')
elif role == 'bot':
return (prev_role == 'user' or prev_role == 'tool_response')
else:
logging.warning(f"unrecognized role in ChatHistory.turn() (role={role})")
return True
def to_list(self, messages=None, html=False):
"""
Serialize the history to a list of dicts, where each dict is a chat entry
with the non-critical keys removed (suitable for web transport, etc.)
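For example, a short chat might serialize to (illustrative output)::

    chat_history.to_list()
    # [{'role': 'system', 'text': 'You are a helpful assistant.'},
    #  {'role': 'user', 'text': 'Why is the sky blue?'}]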
"""
if messages is None:
messages = self.messages
if messages and isinstance(messages[0], ChatMessage):
messages = [{'role' : msg.role, msg.type : msg.content} for msg in messages]
if html:
messages = self.to_html(messages)
return messages
def add_system_prompt(self, system_prompt=None, use_cache=True):
"""
Add the system prompt message to the chat, containing :attr:`ChatHistory.system_prompt`
followed by the tool function descriptions if tools are enabled. If the ``system`` role
is not defined by the model's chat template, then this function does nothing.
Arguments:
use_cache (bool): If true, then the system prompt tokens/embeddings will be cached.
This is the default because the system prompt typically doesn't change.
Returns:
The :class:`ChatMessage` that was added to the chat with the ``system`` role.
"""
if 'system' not in self.template:
return None
if system_prompt is not None:
self.template.system_prompt = system_prompt
return self.append(role='system', text=self.template.system_prompt, use_cache=use_cache)
@property
def system_prompt(self):
"""
Get the system prompt, the typically hidden instruction at the beginning
of the chat like "You are a curious and helpful AI assistant, ..."
"""
return self.template.get('system_prompt', '')
@system_prompt.setter
def system_prompt(self, instruction):
"""
Set the system prompt instruction string and reset the chat history.
TODO make it so this doesn't reset the chat history, but uncaches it.
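For example (the instruction string is illustrative)::

    chat_history.system_prompt = "You are a terse assistant that answers in one sentence."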
"""
if instruction is None:
return
if self.template['system_prompt'] == instruction:
return
self.reset(system_prompt=instruction)
def embed_chat(self, use_cache=True, max_tokens=None, wrap_tokens=None, **kwargs):
"""
Assemble the embedding of either the latest or entire chat.
If ``use_cache=True`` (the default), only the new embeddings will be returned.
If ``use_cache=False``, then the entire chat history will be returned.
This function returns an ``(embedding, position)`` tuple, where the embedding array
contains the new embeddings (or tokens) from the chat, and position is the current
overall position in the history (up to the model's context window length)
If the number of tokens in the chat history exceeds the length given in ``max_tokens`` argument
(which is typically the model's context window, minus the max generation length),
then the chat history will drop all but the most recent ``wrap_tokens``, starting at a user prompt.
If ``max_tokens`` is provided but ``wrap_tokens`` is not, then the overflowing tokens will be truncated.
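A sketch of the typical usage alongside generation (the :meth:`NanoLLM.generate()`
keyword arguments shown here are assumptions, not a definitive signature)::

    chat_history.append('user', 'Why is the sky blue?')
    embedding, position = chat_history.embed_chat()
    reply = model.generate(embedding, kv_cache=chat_history.kv_cache)
    chat_history.append('bot', reply)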
"""
embeddings = []
position = 0
for n, msg in enumerate(self.messages):
if use_cache:
if msg.cached:
position += msg.num_tokens
else:
embeddings.append(msg.embed())
use_cache = False # all entries after this need to be included
else:
embeddings.append(msg.embed())
#if not use_cache and logging.getLogger().isEnabledFor(logging.DEBUG):
# logging.debug(f"chat msg {n} role={msg.role} type={msg.type} tokens={msg.num_tokens} `{msg.template if msg.template else msg.content if isinstance(msg.content, str) else ''}`".replace('\n', '\\n'))
entries = len(embeddings)
embeddings = np.concatenate(embeddings, axis=1) #, position
'''
if max_tokens and position + embeddings.shape[1] > max_tokens:
if wrap_tokens:
self.reset(wrap_tokens=wrap_tokens)
embeddings, position = self.embed_chat(use_cache=False, max_tokens=max_tokens, wrap_tokens=wrap_tokens, **kwargs)
logging.warning(f"Chat overflow, max history lenth {max_tokens} tokens exceeded (keeping the most recent {embeddings.shape[1]} tokens)")
else:
logging.warning(f"Truncating chat history overflow to {max_tokens} tokens")
return embeddings[:,:max_tokens,:], position
'''
logging.debug(f"chat embed entries={entries} shape={embeddings.shape} position={position}")
return embeddings, position
def reindex(self):
"""
Update the linked lists in the messages that refer to each other.
This gets called after messages are added, removed, or their order changed.
You wouldn't typically need to call this yourself.
"""
for i, msg in enumerate(self.messages):
msg.index = i
msg.history = self
if i == 0:
msg.prev = None
elif i > 0:
msg.prev = self.messages[i-1]
msg.prev.next = msg
if i >= len(self.messages) - 1:
msg.next = None
def find_wrap_entry(self, wrap_tokens):
"""
Find the oldest entry from which the chat doesn't exceed ``wrap_tokens``, such that
the entry is a user query. This is used to keep the most recent chat entries
when the history overflows past the model's max context window.
"""
position = 0
for n in range(len(self.messages)-1, -1, -1):
msg = self.messages[n]
position += msg.num_tokens
if position >= wrap_tokens:
for i in range(n+1, len(self.messages)):
if self.messages[i].role == 'user':
return i
def to_html(self, messages=None):
"""
Sanitize message contents to HTML representation, apply code formatting, etc.
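For example (illustrative input and output; the exact formatting depends on the regex rules
defined in the constructor)::

    chat_history.to_html([{'role': 'bot', 'text': 'run `make -j` to build'}])
    # [{'role': 'bot', 'text': 'run <code>make -j</code> to build'}]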
"""
messages = self.to_list(messages, html=False)
def web_text(text):
for stop_token in StopTokens:
text = text.replace(stop_token, '')
text = text.strip()
text = text.strip('\n')
if text.find('<tool_call>') == 0:
text = text.replace('\n', '')
text = text.replace('<s>', '')
text = escape_html(text)
for regex, replace in self.web_regex:
text = regex.sub(replace, text)
return code_tags(text)
def web_image(image):
from nano_llm.web import WebServer
if not isinstance(image, str):
if not hasattr(image, 'filename'):
return None
image = image.filename
if WebServer.Instance:
return os.path.join(WebServer.Instance.mounts.get(os.path.dirname(image), ''), os.path.basename(image))
else:
return image
for entry in messages:
if 'text' in entry:
entry['text'] = web_text(entry['text'])
if 'image' in entry:
entry['image'] = web_image(entry['image'])
if not entry['image']:
del entry['image']
return messages
def run_tools(self, message, tools={}, append=True):
"""
Invoke any function calls in the output text and return the results.
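A sketch of checking a bot reply for tool calls (``available_functions`` is a hypothetical
placeholder for whatever tool functions you have enabled)::

    message = chat_history.append('bot', response)
    tool_response = chat_history.run_tools(message, tools=available_functions)
    if tool_response:
        pass  # generate again so the bot can see the tool results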
"""
if not tools:
return None
if isinstance(message, ChatMessage):
text = message.content if message.is_type('text') else None
elif isinstance(message, dict):
text = message.get('text')
elif isinstance(message, str):
text = message
else:
raise ValueError("expected a message dict or string (was {type(message)})")
if not text:
return None
tool_response = self.BotFunctions.run(text, template=self.template, functions=tools)
if not tool_response:
return None
if append:
self.append('tool_response', tool_response)
return tool_response