Models
The NanoLLM interface provides model loading, quantization, embeddings, and inference.
```python
from nano_llm import NanoLLM

model = NanoLLM.from_pretrained(
    "meta-llama/Llama-3-8b-hf",   # HuggingFace repo/model name, or path to HF model checkpoint
    api='mlc',                    # supported APIs are: mlc, awq, hf
    api_token='hf_abc123def',     # HuggingFace API key for authenticated models ($HUGGINGFACE_TOKEN)
    quantization='q4f16_ft'       # q4f16_ft, q4f16_1, q8f16_0 for MLC, or path to AWQ weights
)

response = model.generate("Once upon a time,", max_new_tokens=128)

for token in response:
    print(token, end='', flush=True)
```
You can run text completion from the command-line like this:
```bash
python3 -m nano_llm.completion --api=mlc \
    --model meta-llama/Llama-3-8b-chat-hf \
    --quantization q4f16_ft \
    --prompt 'Once upon a time,'
```
See the Chat section for examples of running multi-turn chat and function calling.
Supported Architectures
- Llama
- Llava
- StableLM
- Phi-2
- Gemma
- Mistral
- GPT-Neox
These include fine-tuned derivatives that share the same network architecture as above (for example, lmsys/vicuna-7b-v1.5 is a Llama model). Other model types are supported via the various quantization APIs as well - check the associated library documentation for details.
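As a brief illustration, a Llama-derived checkpoint like lmsys/vicuna-7b-v1.5 loads the same way as the base model (a minimal sketch; the MLC quantization setting mirrors the example above):

```python
from nano_llm import NanoLLM

# vicuna-7b shares the Llama architecture, so the same loading path applies
model = NanoLLM.from_pretrained(
    "lmsys/vicuna-7b-v1.5",
    api='mlc',
    quantization='q4f16_ft'
)
```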
Tested Models
Access to Gated Models from HuggingFace Hub
To download models that require authentication, generate an API key and request access (for example, to the Llama models).
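As a sketch, the token can be supplied via the $HUGGINGFACE_TOKEN environment variable or the api_token argument shown in the loading example above; the token string here is a placeholder:

```python
from nano_llm import NanoLLM

# the api_token value is a placeholder - substitute the key generated on HuggingFace Hub,
# or export it as $HUGGINGFACE_TOKEN instead of passing it explicitly
model = NanoLLM.from_pretrained(
    "meta-llama/Llama-3-8b-hf",   # gated repo - requires an accepted license and a token
    api='mlc',
    quantization='q4f16_ft',
    api_token='hf_your_token_here'
)
```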
Large Language Models
Small Language Models (SLM)
Vision Language Models (VLM)
Model API
- class NanoLLM(model_path, **kwargs)[source]
Bases: object
Base class for local LLM APIs. It defines common HuggingFace-like interfaces for model loading, text generation, tokenization, embeddings, and streaming, and also supports multimodal vision models like Llava and generating image embeddings with CLIP. Model backends implement this interface, including generate() for token generation; the static method from_pretrained() will load the model using the specified API.
- static from_pretrained(model, api=None, use_cache=False, **kwargs)[source]
Load a model from the given path or download it from HuggingFace Hub. Various inference and quantization APIs are supported, such as MLC and AWQ. If the API isn’t explicitly specified, it will be inferred from the type of model.
- Parameters:
model (str) – either the path to the model, or HuggingFace model repo/name.
api (str) – the model backend API to use: ‘auto_gptq’, ‘awq’, ‘mlc’, or ‘hf’. If left as None, it will be determined automatically.
quantization (str) – for AWQ or MLC, either the quantization method or the path to the quantized model (AWQ and MLC APIs only)
vision_model (str) – for VLMs, override the vision embedding model (typically openai/clip-vit-large-patch14-336). Otherwise, it will use the CLIP variant from the config.
- Returns:
A loaded NanoLLM model instance using the determined API.
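A minimal sketch of the two loading styles described above (the repo name reuses the earlier example): leaving api=None lets the backend be inferred from the model, or it can be pinned explicitly together with the quantization method:

```python
from nano_llm import NanoLLM

# let the backend API be inferred from the model type
model = NanoLLM.from_pretrained("meta-llama/Llama-3-8b-hf")

# or pin the backend and quantization explicitly
model = NanoLLM.from_pretrained(
    "meta-llama/Llama-3-8b-hf",
    api='mlc',                # one of: 'auto_gptq', 'awq', 'mlc', 'hf'
    quantization='q4f16_ft'   # MLC quantization method, or a path to AWQ weights
)
```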
- generate(inputs, streaming=True, **kwargs)[source]
Generate output from input text, tokens, or an embedding. For detailed kwarg descriptions, see transformers.GenerationConfig.
- Parameters:
inputs (str|ndarray) – Text or embedding inputs to the model.
streaming (bool) – If True, an iterator will be returned that returns text chunks. Otherwise, this function will block and return the generated text.
functions (list[callable]) – Dynamic functions or plugins to run inline with token generation for things like function calling, guidance, token healing, etc. These will be passed the text generated by the LLM so far, and any additional text that they return will be added to the chat.
max_new_tokens (int) – The number of tokens to output in addition to the prompt (default: 128)
min_new_tokens (int) – Force the model to generate a set number of output tokens (default: -1)
do_sample (bool) – If True, temperature/top_p will be used. Otherwise, greedy search (default: False)
repetition_penalty – The parameter for repetition penalty. 1.0 means no penalty (default: 1.0)
temperature (float) – Randomness token sampling parameter (default: 0.7, only used if do_sample=True)
top_p (float) – If set to float < 1 and do_sample=True, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation (default: 0.95)
stop_tokens (list[int]|list[str]) – Stop generation if the bot produces tokens or text from this list (defaults to the EOS token ID)
kv_cache (np.ndarray) – Previous kv_cache that the inputs will be appended to. By default, a blank kv_cache will be created for each generation (i.e. a new chat). This generation’s kv_cache will be set in the returned StreamingResponse iterator after the request is complete.
- Returns:
An asynchronous StreamingResponse iterator (when streaming=True) that outputs one decoded token string at a time. Otherwise, this function blocks and a string containing the full reply is returned after it’s been completed.
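A short sketch of both return modes, given a model loaded as in the example above (the prompt and token counts are illustrative):

```python
# streaming (default): iterate over decoded token strings as they arrive
response = model.generate("Once upon a time,", max_new_tokens=64)
for token in response:
    print(token, end='', flush=True)

# blocking: returns the completed reply as a single string
text = model.generate("Once upon a time,", streaming=False, max_new_tokens=64)
print(text)
```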
- tokenize(text, add_special_tokens=False, dtype=np.int32, return_tensors='np', **kwargs)[source]
Tokenize the given string and return the encoded token IDs.
- Parameters:
text (str) – the text to tokenize.
add_special_tokens (bool) – if BOS/EOS tokens (like <s> or <|endoftext|>) should automatically be added (default: False)
dtype (type) – the numpy or torch datatype of the tensor to return.
return_tensors (str) – 'np' to return a np.ndarray or 'pt' to return a torch.Tensor
kwargs – additional arguments forwarded to the HuggingFace transformers.AutoTokenizer encode function.
- Returns:
The token IDs with the tensor type as indicated by return_tensors (either ‘np’ for np.ndarray or ‘pt’ for torch.Tensor) and datatype as indicated by dtype (by default int32).
- detokenize(tokens, skip_special_tokens=False, **kwargs) → str [source]
Detokenize the given token IDs and return the decoded string.
- Parameters:
tokens (list[int], np.ndarray, torch.Tensor) – the array of token IDs
skip_special_tokens (bool) – if special tokens (like BOS/EOS) should be suppressed from the output or not (default: False)
kwargs – additional arguments forwarded to the HuggingFace transformers.AutoTokenizer decode function.
- Returns:
The string containing the decoded text.
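A minimal round-trip sketch of tokenize() and detokenize(), given a model loaded as above; the squeeze() call assumes the encoded array may carry a leading batch dimension:

```python
# encode text to int32 token IDs (np.ndarray by default), then decode back
tokens = model.tokenize("Once upon a time,")
text = model.detokenize(tokens.squeeze())   # squeeze any batch dimension before decoding

print(tokens.dtype, tokens.shape)
print(text)
```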
- embed_text(text, add_special_tokens=False, use_cache=False, return_tensors='np', return_tokens=False, **kwargs)[source]
Tokenize the string with NanoLLM.tokenize() and return its embedding as computed by NanoLLM.embed_tokens(). Note that if model.has_embed=False, then None will be returned for the embedding and the tokens should be used instead.
- Parameters:
text (str) – the text to tokenize and embed.
add_special_tokens (bool) – if BOS/EOS tokens (like <s> or <|endoftext|>) should automatically be added (default: False)
use_cache (bool) – if True, the text embedding will be cached and returned without additional computation if the same string was already embedded previously. This is useful for things like the system prompt that are relatively static, but probably shouldn’t be used for dynamic user inputs that are unlikely to be re-used again (leading to unnecessarily increased memory usage). The default is False.
return_tensors (str) – 'np' to return a np.ndarray or 'pt' to return a torch.Tensor
return_tokens (bool) – if True, then the tokens will also be returned in addition to the embedding.
kwargs – additional arguments forwarded to NanoLLM.tokenize() and the HuggingFace transformers.AutoTokenizer.
- Returns:
The embedding with the tensor type as indicated by return_tensors (either ‘np’ for np.ndarray or ‘pt’ for torch.Tensor) with float32 data. If return_tokens=True, then an (embedding, tokens) tuple will be returned instead of only the embedding. If model.has_embed=False, then the embedding will be None.
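A brief sketch of embed_text(), assuming a model loaded as above with model.has_embed=True; the prompt strings are illustrative:

```python
# cache the embedding of a static system prompt so repeat requests are free
system_embed = model.embed_text("You are a helpful assistant.", use_cache=True)

# per-request user input - leave the cache off to avoid growing memory
user_embed, user_tokens = model.embed_text("What is the weather like today?", return_tokens=True)

# if model.has_embed were False, the embeddings above would be None
# and the returned tokens should be used instead
```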
- embed_tokens(tokens, return_tensors='np', **kwargs)[source]
Compute the token embedding and return its tensor. This will raise an exception if model.has_embed=False.
- Parameters:
tokens (list[int], np.ndarray, torch.Tensor) – the array of token IDs
return_tensors (str) – 'np' to return a np.ndarray or 'pt' to return a torch.Tensor
- Returns:
The embedding with the tensor type as indicated by return_tensors (either ‘np’ for np.ndarray or ‘pt’ for torch.Tensor) with float32 data.
- embed_image(image, return_tensors='pt', return_dict=False, **kwargs)[source]
Compute the embedding of an image (for multimodal models with a vision encoder like CLIP), and apply any additional projection layers as specified by the model.
- Parameters:
image (pil.Image, np.ndarray, torch.Tensor, jetson.utils.cudaImage, __cuda_array_interface__) – the image
return_tensors (str) – 'np' to return a np.ndarray or 'pt' to return a torch.Tensor (on the GPU)
return_dict (bool) – if True, return a dict including the vision encoder’s hidden_state and embedding
kwargs – additional arguments forwarded to the vision encoder (nano_llm.vision.CLIPImageEmbedding)
- Returns:
The embedding with the tensor type as indicated by return_tensors (either ‘np’ for np.ndarray or ‘pt’ for torch.Tensor), or a dict containing the embedding and the vision encoder’s hidden_state if return_dict=True.
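A short sketch for multimodal models (model.has_vision=True), given a VLM loaded as above; the image path is a placeholder:

```python
from PIL import Image

image = Image.open('test_image.jpg')   # placeholder path to a local image
embedding = model.embed_image(image)   # torch.Tensor on the GPU by default

# also retrieve the vision encoder's hidden_state alongside the projected embedding
outputs = model.embed_image(image, return_dict=True)
```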
- config_path
The local path to the model config file (config.json)
- model_path
The local path to the model checkpoint/weights in HuggingFace format.
- config
Dict containing the model configuration (inspect it on the HuggingFace model card)
- stats
Dict containing the latest generation performance statistics.
- has_vision
True if this is a multimodal vision/language model.
- has_embed
True if this model has a separate text embedding layer for embed_text()
- tokenizer
HuggingFace transformers.AutoTokenizer instance used for tokenization/detokenization.
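A quick sketch of inspecting these attributes after loading a model as above and running a generation (the prompt is illustrative):

```python
print(model.config_path)   # local path to config.json
print(model.has_vision)    # True for multimodal vision/language models
print(model.has_embed)     # True if embed_text() is available

reply = model.generate("Hello!", streaming=False, max_new_tokens=16)
print(model.stats)         # latest generation performance statistics
```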
Streaming
- class StreamingResponse(model, input, **kwargs)[source]
Bases: object
Asynchronous output iterator returned from NanoLLM.generate(). Use it to stream the reply from the LLM as it is decoded token-by-token:
```python
response = model.generate("Once upon a time,")

for token in response:
    print(token, end='', flush=True)
```
The entire response generated so far is also stored in StreamingResponse.tokens and StreamingResponse.text. To terminate processing prematurely, call StreamingResponse.stop(), which will signal the model to stop generating additional output tokens.
- tokens
accumulated output tokens generated so far (for the whole reply)
- text
detokenized output text generated so far (for the whole reply)
- delta
the new text added since the iterator was last read
- input
the original input query from the user
- stopping
set if the user requested early termination
- stopped
set when generation has actually stopped
- __next__()[source]
Wait until the model generates more output, and return the new text (only the delta)
- property eos
Returns True if the End of Sequence (EOS) token was reached and generation has stopped.
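A brief sketch of inspecting the accumulated reply and stopping early, given a model loaded as above; the length threshold is arbitrary:

```python
response = model.generate("Tell me a long story about a robot.", max_new_tokens=512)

for token in response:
    print(token, end='', flush=True)
    if len(response.text) > 200:   # arbitrary cutoff for illustration
        response.stop()            # signal the model to stop generating further tokens

print('\nstopped early:', response.stopped)
```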
KV Cache
- class KVCache[source]
Bases: object
Abstract interface for storing and manipulating the KV cache, which encodes all the context and model state in the chat. These are implemented by the different LLM backends and are backed by CUDA memory for each layer in the model, which these functions provide a way to modify.
It gets returned in the StreamingResponse iterator from NanoLLM.generate(), and can optionally be re-used during the next generation to grow the cache instead of having to refill the chat context each request.
For example, KVCache.pop() will drop the most recent N tokens off the end of the cache, while KVCache.remove() will remove a range of tokens from anywhere in the cache.
The ChatHistory object provides a higher-level way of maintaining consistency when removing messages from the chat, by keeping track of their token counts and positions in the chat. It also keeps the KV cache between requests, so that only the new tokens need to be added (and the model only processes those).
- num_tokens
The current length of the KV cache
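A minimal sketch of re-using the cache across requests, given a model loaded as above; reading the cache back from the response as response.kv_cache is an assumption based on the description above (the docs state the cache is stored in the returned StreamingResponse once the request completes):

```python
# first request - a blank KV cache is created for the new chat
first = model.generate("My name is Riley.", max_new_tokens=32)
for token in first:
    print(token, end='', flush=True)

# follow-up request - grow the existing cache instead of refilling the context
second = model.generate(
    "What is my name?",
    kv_cache=first.kv_cache,   # assumed attribute on StreamingResponse (see note above)
    max_new_tokens=32,
)
for token in second:
    print(token, end='', flush=True)
```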