Source code for clip_trt.utils.tensor

#!/usr/bin/env python3
import torch
import numpy as np


__all__ = [
    'cudaArrayInterface', 'torch_dtype_dict', 'torch_dtype', 
    'convert_dtype', 'convert_tensor', 'is_embedding'
]


class cudaArrayInterface():
    """
    Exposes __cuda_array_interface__ - typically used as a temporary view into a larger buffer
    https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html
    """
    def __init__(self, data, shape, dtype=np.float32):
        if dtype == np.float32:
            typestr = 'f4'
        elif dtype == np.float64:
            typestr = 'f8'
        elif dtype == np.float16:
            typestr = 'f2'
        else:
            raise RuntimeError(f"unsupported dtype:  {dtype}")
            
        self.__cuda_array_interface__ = {
            'data': (data, False),  # R/W
            'shape': shape,
            'typestr': typestr,
            'version': 3,
        }  

        
torch_dtype_dict = {
    'bool'       : torch.bool,
    'uint8'      : torch.uint8,
    'int8'       : torch.int8,
    'int16'      : torch.int16,
    'int32'      : torch.int32,
    'int64'      : torch.int64,
    'float16'    : torch.float16,
    'float32'    : torch.float32,
    'float64'    : torch.float64,
    'complex64'  : torch.complex64,
    'complex128' : torch.complex128
}


def torch_dtype(dtype):
    """
    Convert numpy.dtype or str to torch.dtype
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    elif not isinstance(dtype, type):
        # from np.dtype() (not a built-in np.float32, ect)
        torch_dtype = torch_dtype_dict.get(str(dtype))
        
        if torch_dtype is None:
            raise ValueError("unknown dtype {dtype}  (type={type(dtype)}")
            
        return torch_dtype

    if dtype == np.float32:      return torch.float32
    elif dtype == np.float64:    return torch.float64
    elif dtype == np.int8:       return torch.int8
    elif dtype == np.int16:      return torch.int16
    elif dtype == np.int32:      return torch.int32
    elif dtype == np.int64:      return torch.int64
    elif dtype == np.uint8:      return torch.uint8
    elif dtype == np.uint16:     return torch.uint16
    elif dtype == np.uint32:     return torch.uint32
    elif dtype == np.uint64:     return torch.uint64
    elif dtype == np.complex64:  return torch.complex64
    elif dtype == np.complex128: return torch.complex128
    elif dtype == np.bool_:      return torch.bool
    
    raise ValueError("unknown dtype {dtype}  (type={type(dtype)}")


[docs] def convert_dtype(dtype, to='np'): """ Convert a string, numpy type, or torch.dtype to either numpy or PyTorch """ if dtype is None: return None if to == 'pt': return torch_dtype(dtype) elif to == 'np': if isinstance(dtype, type): return dtype elif isinstance(dtype, torch.dtype): return np.dtype(str(dtype).split('.')[-1]) # remove the torch.* prefix else: return np.dtype(dtype) raise TypeError(f"expected dtype as a string, type, or torch.dtype (was {type(dtype)}) and with to='np' or to='pt' (was {to})")
[docs] def convert_tensor(tensor, return_tensors='pt', device=None, dtype=None, **kwargs): """ Convert tensors between numpy/torch/ect """ if tensor is None: return None dtype = convert_dtype(dtype, to=return_tensors) if isinstance(tensor, np.ndarray): if return_tensors == 'np': # np->np if dtype: tensor = tensor.astype(dtype=convert_dtype(dtype, to='np'), copy=False) return tensor elif return_tensors == 'pt': # np->pt return torch.from_numpy(tensor).to(device=device, dtype=convert_dtype(dtype, to='pt'), **kwargs) elif isinstance(tensor, torch.Tensor): if return_tensors == 'np': # pt->np if dtype: tensor = tensor.type(dtype=convert_dtype(dtype, to='pt')) return tensor.detach().cpu().numpy() elif return_tensors == 'pt': # pt->pt if device is not None or dtype is not None: return tensor.to(device=device, dtype=convert_dtype(dtype, to='pt'), **kwargs) else: return tensor elif isinstance(tensor, list): if return_tensors == 'np': return np.asarray(tensor, dtype=dtype) elif return_tensors == 'pt': return torch.as_tensor(tensor, dtype=dtype, device=device) raise ValueError(f"unsupported tensor input/output type (in={type(tensor)} out={return_tensors})")
def is_embedding(tensor): """ Determine if a tensor is likely to be an embedding (torch.Tensor or np.ndarray with dtype=float32) """ if isinstance(tensor, torch.Tensor) and torch.is_floating_point(tensor.dtype): return True elif isinstance(tensor, np.ndarray) and (tensor.dtype == np.float16 or tensor.dtype == np.float32 or tensor.dtype == np.float64): return True else: return False