#!/usr/bin/env python3
import io
import logging
import urllib.request

import numpy as np
import PIL
import PIL.Image
import torch
import torchvision.transforms.functional as F
# Types accepted as images by the helpers in this module. The tuple is
# extended with cudaImage below when jetson_utils can be imported.
ImageTypes = (PIL.Image.Image, np.ndarray, torch.Tensor)

# Image file extensions exported for callers (not referenced in this file).
ImageExtensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

# jetson_utils is optional: when it cannot be imported, the cudaImage code
# paths are disabled through HAS_JETSON_UTILS instead of failing the import
# of this module.
try:
    from jetson_utils import cudaImage, cudaFromNumpy
except Exception:
    # Kept broad on purpose so that any import-time failure of the optional
    # package still falls back gracefully, but (unlike a bare ``except:``)
    # KeyboardInterrupt and SystemExit are allowed to propagate.
    HAS_JETSON_UTILS = False
else:
    HAS_JETSON_UTILS = True
    ImageTypes = (*ImageTypes, cudaImage)

__all__ = [
    'ImageTypes', 'ImageExtensions', 'is_image', 'image_size',
    'load_image', 'cuda_image', 'torch_image', 'torch_image_format',
]
[docs]
def is_image(image):
    """
    Check whether ``image`` is an instance of one of the types in ``ImageTypes``.

    ``ImageTypes`` always holds ``PIL.Image.Image``, ``np.ndarray`` and ``torch.Tensor``,
    and additionally ``jetson_utils.cudaImage`` when jetson_utils was imported successfully.

    Returns:
      ``True`` if the object is one of those types, ``False`` otherwise.
    """
    supported = isinstance(image, ImageTypes)
    return supported
    
 
[docs]
def image_size(image):
    """
    Returns the dimensions of the image.

    For ``np.ndarray``, ``torch.Tensor`` and ``cudaImage`` inputs this is the object's
    ``.shape`` returned unchanged, so the axis order is whatever layout the object uses.
    For ``PIL.Image`` inputs this is ``image.size``, which Pillow documents as
    ``(width, height)``.

    Raises:
      TypeError: if ``image`` is not one of the supported image types.
    """
    # cudaImage is only defined when the optional jetson_utils import succeeded,
    # so it is only added to the check in that case.
    shaped_types = (np.ndarray, torch.Tensor)
    if HAS_JETSON_UTILS:
        shaped_types = (cudaImage, *shaped_types)

    if isinstance(image, shaped_types):
        return image.shape

    if isinstance(image, PIL.Image.Image):
        return image.size

    raise TypeError(f"expected an image of type {ImageTypes} (was {type(image)})")
        
    
[docs]
def load_image(path):
    """
    Load an image from a local path or from an HTTP(S) URL that will be downloaded.

    Args:
      path (str): either a filesystem path or an ``http://`` / ``https://`` URL to the image.

    Returns:
      ``PIL.Image`` instance converted to RGB mode.

    Raises:
      urllib.error.URLError: if the download fails (includes ``HTTPError`` for error statuses).
      OSError: if the file cannot be opened or is not a readable image.
    """
    # Match the full scheme so that local paths which merely begin with
    # "http" (e.g. "http_cache/a.png") are not mistaken for URLs.
    if path.startswith(('http://', 'https://')):
        logging.debug(f'downloading {path}')
        with urllib.request.urlopen(path) as response:
            source = io.BytesIO(response.read())
    else:
        logging.debug(f'loading {path}')
        source = path

    # convert() decodes the pixel data into a new image, so the underlying
    # file can be closed as soon as the with-block exits.
    with PIL.Image.open(source) as image:
        return image.convert('RGB')
[docs]
def cuda_image(image):
    """
    Convert an image from `PIL.Image`, `np.ndarray`, or `torch.Tensor` to a
    jetson_utils.cudaImage (`__gpu_array_interface__` objects are not implemented yet).

    A ``cudaImage`` input is returned as-is. PIL images are first converted to a numpy
    array and then, like numpy arrays, passed through ``cudaFromNumpy``. Tensors are
    wrapped by pointer, see the notes in the tensor branch below.

    Raises:
      RuntimeError: if jetson_utils is not installed.
      TypeError: if ``image`` is not one of the supported image types.
      ValueError: if a tensor has an unsupported dtype, rank, or channel count
        (raised by ``torch_image_format``).
    """
    if not HAS_JETSON_UTILS:
        raise RuntimeError("jetson-utils should be installed to use cudaImage")

    # TODO implement __gpu_array_interface__
    # TODO torch image formats https://github.com/dusty-nv/jetson-utils/blob/f0bff5c502f9ac6b10aa2912f1324797df94bc2d/python/examples/cuda-from-pytorch.py#L47
    if not is_image(image):
        raise TypeError(f"expected an image of type {ImageTypes} (was {type(image)})")

    if isinstance(image, cudaImage):
        return image

    if isinstance(image, PIL.Image.Image):
        image = np.asarray(image)

    if isinstance(image, np.ndarray):
        return cudaFromNumpy(image)

    if isinstance(image, torch.Tensor):
        # Validate dtype / rank / channel count before touching the memory layout.
        image_format = torch_image_format(image)

        # torch.channels_last is only defined for rank-4 tensors, so a CHW
        # tensor is viewed as 1xCxHxW first; the resulting memory order is
        # then H, W, C (interleaved channels).
        if image.dim() == 3:
            image = image.unsqueeze(0)

        if image.dim() >= 4:
            tensor = image.to(memory_format=torch.channels_last)   # or tensor.permute(0, 3, 2, 1)
        else:
            tensor = image.contiguous()   # HW layout: a single channel, row-major

        # NOTE(review): the cudaImage is built from a raw pointer. If the
        # layout conversion above had to copy, `tensor` is a new allocation
        # that is only referenced by this function - confirm that cudaImage
        # (or the caller) keeps the memory alive. Presumably the tensor also
        # has to live in CUDA-accessible memory; verify against jetson_utils.
        return cudaImage(
            ptr=tensor.data_ptr(),
            width=tensor.shape[-1],
            height=tensor.shape[-2],
            format=image_format
        )
        
 
[docs]
def torch_image(image, dtype=None, device=None):
    """
    Convert an image of any supported type to a ``torch.Tensor``.

    Args:
      image: a ``PIL.Image``, ``np.ndarray``, ``torch.Tensor``, or ``cudaImage``.
      dtype: optional torch dtype passed to the final ``.to()`` call.
      device: optional torch device passed to the final ``.to()`` call.

    Returns:
      ``torch.Tensor`` moved/cast with ``.to(dtype=dtype, device=device)``.

    Raises:
      TypeError: if ``image`` is not one of the supported image types.
    """
    if not isinstance(image, ImageTypes):
        raise TypeError(f"expected an image of type {ImageTypes} (was {type(image)})")

    # cudaImage only exists when the optional jetson_utils import succeeded.
    from_cuda = HAS_JETSON_UTILS and isinstance(image, cudaImage)

    if from_cuda:
        # Move the last axis to the front (presumably HWC -> CHW).
        tensor = torch.as_tensor(image, dtype=dtype, device=device).permute(2,0,1)
        # Only these two float dtypes are rescaled by 1/255.
        if dtype in (torch.float16, torch.float32):
            tensor = tensor / 255.0
        image = tensor
    elif isinstance(image, (PIL.Image.Image, np.ndarray)):
        # NOTE(review): per the torchvision docs to_tensor() yields a float
        # tensor scaled to [0, 1] for 8-bit inputs, so requesting an integer
        # dtype truncates those values in the cast below - confirm intended.
        image = F.to_tensor(image)

    # torch.Tensor inputs skip both branches and are only cast / moved here.
    return image.to(dtype=dtype, device=device)
        
        
def torch_image_format(tensor):
    """
    Determine the cudaImage format string (eg 'rgb32f', 'rgba32f', etc) from a PyTorch tensor.
    Only float32 and uint8 tensors are supported because those datatypes are supported by cudaImage.

    The channel count is read from ``shape[1]`` for tensors of rank 4 or more (NCHW),
    from ``shape[0]`` for rank 3 (CHW), and is 1 for rank 2 (HW).

    Returns:
      One of ``'gray8'``, ``'gray32f'``, ``'rgb8'``, ``'rgb32f'``, ``'rgba8'``, ``'rgba32f'``.

    Raises:
      ValueError: if the dtype is not float32/uint8, the tensor has fewer than
        2 dimensions, or the channel count is not 1, 3, or 4.
    """
    if tensor.dtype not in (torch.float32, torch.uint8):
        raise ValueError(f"PyTorch tensor datatype should be torch.float32 or torch.uint8 (was {tensor.dtype})")

    rank = tensor.dim()

    if rank >= 4:      # NCHW layout
        channels = tensor.shape[1]
    elif rank == 3:    # CHW layout
        channels = tensor.shape[0]
    elif rank == 2:    # HW layout
        channels = 1
    else:
        # torch.Size has no ``length`` attribute; report the rank directly so
        # the documented ValueError is raised instead of an AttributeError.
        raise ValueError(f"PyTorch tensor should have at least 2 image dimensions (has {rank})")

    # channel count -> (uint8 format, float32 format)
    formats = {
        1: ('gray8', 'gray32f'),
        3: ('rgb8', 'rgb32f'),
        4: ('rgba8', 'rgba32f'),
    }

    if channels not in formats:
        raise ValueError(f"PyTorch tensor should have 1, 3, or 4 image channels (has {channels})")

    uint8_format, float_format = formats[channels]
    return float_format if tensor.dtype == torch.float32 else uint8_format