Source code for nano_llm.chat.stream

#!/usr/bin/env python3
import threading

import torch
import numpy as np

from nano_llm.utils import ends_with_token


class StreamingResponse():
    """
    Asynchronous output iterator returned from :meth:`NanoLLM.generate`.
    Use it to stream the reply from the LLM as it is decoded token-by-token::

        response = model.generate("Once upon a time,")

        for token in response:
            print(token, end='', flush=True)

    The entire response generated so far is also stored in
    :attr:`StreamingResponse.tokens` and :attr:`StreamingResponse.text`.

    To terminate processing prematurely, call :meth:`StreamingResponse.stop`,
    which will signal the model to stop generating additional output tokens.
    """
    def __init__(self, model, input, **kwargs):
        super().__init__()

        #: accumulated output tokens generated so far (for the whole reply)
        self.tokens = []

        #: detokenized output text generated so far (for the whole reply)
        self.text = ''

        #: the new text added since the iterator was last read
        self.delta = ''

        #: the original input query from the user
        self.input = input

        #: the :class:`NanoLLM` model instance being used to generate the output
        self.model = model

        #: the :class:`KVCache` used by this request
        self.kv_cache = kwargs.get('kv_cache', None)

        #: set if the user requested early termination
        self.stopping = False

        #: set when generation has actually stopped
        self.stopped = False

        self.event = threading.Event()
        self.kwargs = kwargs

    def __iter__(self):
        return self
    def __next__(self):
        """
        Wait until the model generates more output, and return the new text (only the delta)
        """
        if self.stopped:
            '''
            # early-stop EOS token is now added inside LLM APIs
            stop_tokens = self.kwargs.get('stop_tokens', [self.model.tokenizer.eos_token_id])

            if not ends_with_token(self.tokens, stop_tokens, self.model.tokenizer):
                self.add_tokens(self.model.tokenizer.eos_token_id)  # add EOS if necessary

            return self._pop_delta()
            '''
            delta = self._pop_delta()

            if delta:
                return delta
            else:
                raise StopIteration

        self.event.wait()
        self.event.clear()

        return self._pop_delta()
    @property
    def eos(self):
        """
        Returns true if End of Sequence (EOS) has been reached and generation has stopped.
        """
        return self.stopped
    def stop(self):
        """
        Signal the model to halt output generation before the end of the reply.
        """
        self.stopping = True
    def add_tokens(self, tokens, detokenize=True, event=False):
        """
        Add an output token, detokenize the reply, and accumulate the delta message.
        This function is only used by the model APIs when they generate a new token.
        """
        if isinstance(tokens, (torch.Tensor, np.ndarray)):
            tokens = tokens.squeeze().tolist()

        if isinstance(tokens, list):
            self.tokens.extend(tokens)
        elif tokens is not None:
            self.tokens.append(tokens)

        if not detokenize:
            return

        # detokenize the entire reply on each new output token, because multiple tokens can
        # combine with each other, changing the previous text (like with long words and unicode)
        message = self.model.tokenizer.decode(self.tokens, skip_special_tokens=False, clean_up_tokenization_spaces=False)

        self.delta = self.delta + message[len(self.text):]
        self.text = message

        if event:
            self.event.set()
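    # Note on the full re-decode above: with byte-level / BPE tokenizers, a single character
    # can be split across several tokens, so decoding each new token on its own may not
    # reproduce the text that decoding the whole sequence yields.  A minimal sketch of the
    # effect (assumes the HuggingFace `transformers` package; "gpt2" is only an illustrative
    # tokenizer choice, not one used by this module):
    #
    #   from transformers import AutoTokenizer
    #   tok = AutoTokenizer.from_pretrained("gpt2")
    #   ids = tok.encode("naïve café")                  # contains multi-byte characters
    #   ''.join(tok.decode([i]) for i in ids)           # per-token decode: may contain '�'
    #   tok.decode(ids)                                 # full decode: 'naïve café'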
    def _pop_delta(self, reset=True):
        """
        Get the text that has accumulated since the iterator was last read, and reset it.
        """
        delta = self.delta

        if reset:
            self.delta = ''

        return delta
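
# ------------------------------------------------------------------------------
# Usage sketch (not part of the module): consuming a StreamingResponse and
# stopping it early.  Assumes `model` is an already-loaded NanoLLM instance;
# the prompt and the 64-token cutoff are illustrative only.
#
#   response = model.generate("Once upon a time,")
#
#   for token in response:
#       print(token, end='', flush=True)
#       if len(response.tokens) > 64:
#           response.stop()           # request early termination
#
#   print('\ndone:', response.eos)    # True once generation has actually stopped
# ------------------------------------------------------------------------------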