#!/usr/bin/env python3
import io
import logging

import numpy as np
import PIL
import PIL.Image
import torch
import torchvision.transforms.functional as F
# Types accepted as images by this module's helpers; jetson_utils.cudaImage is
# appended to this tuple when jetson-utils can be imported.
ImageTypes = (PIL.Image.Image, np.ndarray, torch.Tensor)

# Image file extensions (lowercase, with leading dot), exported via __all__.
ImageExtensions = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

# jetson-utils is optional; HAS_JETSON_UTILS records whether cudaImage support is available.
try:
    from jetson_utils import cudaImage, cudaFromNumpy
except ImportError as error:
    # catch only import failures, so unrelated errors (and KeyboardInterrupt) still surface
    logging.debug(f"jetson_utils could not be imported, cudaImage support disabled ({error})")
    HAS_JETSON_UTILS = False
else:
    HAS_JETSON_UTILS = True
    ImageTypes = (*ImageTypes, cudaImage)

__all__ = [
    'ImageTypes', 'ImageExtensions', 'is_image', 'image_size',
    'load_image', 'cuda_image', 'torch_image', 'torch_image_format',
]
[docs]
def is_image(image):
    """
    Check whether ``image`` is a supported image object.

    Args:
        image: any object.

    Returns:
        bool: ``True`` if ``image`` is an instance of any type in ``ImageTypes``
        (``PIL.Image.Image``, ``np.ndarray``, ``torch.Tensor``, plus
        ``jetson_utils.cudaImage`` when jetson-utils imported successfully),
        otherwise ``False``.
    """
    accepted = ImageTypes
    return isinstance(image, accepted)
[docs]
def image_size(image):
    """
    Return the dimensions reported by ``image``.

    * ``np.ndarray``, ``torch.Tensor`` and ``jetson_utils.cudaImage``: the object's
      ``.shape`` -- e.g. ``(height, width, channels)`` for an HWC array, or
      ``(channels, height, width)`` for a CHW tensor.
    * ``PIL.Image.Image``: PIL's ``.size``, which is ``(width, height)``.

    NOTE(review): the PIL result uses a different order (and has no channel count)
    than the ``(height, width, channels)`` convention; callers may rely on this,
    so it is intentionally left as-is.

    Raises:
        TypeError: if ``image`` is not one of ``ImageTypes``.
    """
    shaped_types = (np.ndarray, torch.Tensor)
    if HAS_JETSON_UTILS:
        shaped_types += (cudaImage,)

    if isinstance(image, shaped_types):
        return image.shape

    if isinstance(image, PIL.Image.Image):
        return image.size

    raise TypeError(f"expected an image of type {ImageTypes} (was {type(image)})")
[docs]
def load_image(path):
    """
    Load an image from a local path, or download it when ``path`` is an http(s) URL.

    Args:
        path (str or os.PathLike): a local file path, or a URL beginning with
            ``http://`` or ``https://``.

    Returns:
        ``PIL.Image`` instance converted to RGB mode.

    Raises:
        urllib.error.URLError: if the URL cannot be fetched (HTTPError for 4xx/5xx).
        OSError: if the file cannot be opened or decoded as an image.
    """
    import os
    import urllib.request

    path = os.fspath(path)

    # match the full scheme prefix so local files like 'http_cat.png' aren't treated as URLs
    if path.startswith(('http://', 'https://')):
        logging.debug(f'downloading {path}')
        # urlopen raises on HTTP error statuses instead of handing an error page to PIL
        with urllib.request.urlopen(path) as response:
            data = response.read()
        with PIL.Image.open(io.BytesIO(data)) as image:
            return image.convert('RGB')

    logging.debug(f'loading {path}')
    # convert() returns a new image, so the source file can be closed right away
    with PIL.Image.open(path) as image:
        return image.convert('RGB')
[docs]
def cuda_image(image):
    """
    Convert an image to a ``jetson_utils.cudaImage`` on the GPU.

    Supported inputs:

    * ``jetson_utils.cudaImage`` -- returned unchanged.
    * ``PIL.Image`` -- converted to ``np.ndarray``, then handled like one.
    * ``np.ndarray`` -- passed to ``jetson_utils.cudaFromNumpy``.
    * ``torch.Tensor`` in HW, CHW or NCHW layout with dtype ``float32`` or ``uint8``
      -- converted to channels-last (interleaved) memory and wrapped in a cudaImage
      built from the tensor's ``data_ptr()`` instead of being uploaded separately.

    Args:
        image: one of ``ImageTypes``.

    Returns:
        jetson_utils.cudaImage

    Raises:
        RuntimeError: if jetson-utils could not be imported.
        TypeError: if ``image`` is not one of ``ImageTypes``.
        ValueError: from ``torch_image_format`` for tensors with an unsupported
            dtype, rank, or channel count.
    """
    if not HAS_JETSON_UTILS:
        raise RuntimeError(f"jetson-utils should be installed to use cudaImage")

    # TODO implement __gpu_array_interface__
    # TODO torch image formats https://github.com/dusty-nv/jetson-utils/blob/f0bff5c502f9ac6b10aa2912f1324797df94bc2d/python/examples/cuda-from-pytorch.py#L47

    if not is_image(image):
        raise TypeError(f"expected an image of type {ImageTypes} (was {type(image)})")

    if isinstance(image, cudaImage):
        return image

    if isinstance(image, PIL.Image.Image):
        image = np.asarray(image)

    if isinstance(image, np.ndarray):
        return cudaFromNumpy(image)

    if isinstance(image, torch.Tensor):
        tensor = image

        # validate dtype / rank / channel count before touching memory layout
        image_format = torch_image_format(tensor)

        # channels_last memory format only applies to rank-4 tensors,
        # so lift HW -> 1x1xHxW and CHW -> 1xCxHxW first
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(0).unsqueeze(0)
        elif tensor.dim() == 3:
            tensor = tensor.unsqueeze(0)

        # interleave the colors in memory (NCHW -> NHWC); the logical shape stays
        # NCHW, so shape[-1] is still the width and shape[-2] the height
        tensor = tensor.to(memory_format=torch.channels_last)  # or tensor.permute(0, 3, 2, 1)

        # NOTE(review): the cudaImage is built from a raw pointer and presumably does not
        # keep `tensor` alive -- if .unsqueeze/.to above produced a copy, that copy is
        # released when this function returns. The tensor is also presumably expected to
        # already be in CUDA memory, and with N > 1 only the first image is described.
        # Confirm against jetson-utils before relying on this path.
        return cudaImage(
            ptr=tensor.data_ptr(),
            width=tensor.shape[-1],
            height=tensor.shape[-2],
            format=image_format,
        )
[docs]
def torch_image(image, dtype=None, device=None):
    """
    Convert an image to a ``torch.Tensor`` in channels-first layout.

    * ``jetson_utils.cudaImage`` -- wrapped with ``torch.as_tensor`` and permuted with
      ``(2, 0, 1)`` (HWC -> CHW, assuming the cudaImage maps to an HxWxC array).
      When a floating-point ``dtype`` is requested, values are divided by 255.
    * ``PIL.Image`` / ``np.ndarray`` -- converted with torchvision's ``to_tensor``.
    * ``torch.Tensor`` -- used as-is.

    The result is finally passed through ``.to(dtype=dtype, device=device)``.

    Args:
        image: one of ``ImageTypes``.
        dtype (torch.dtype, optional): target dtype, forwarded to ``torch`` as-is.
        device (str or torch.device, optional): target device, forwarded as-is.

    Returns:
        torch.Tensor

    Raises:
        TypeError: if ``image`` is not one of ``ImageTypes``.
    """
    if not is_image(image):
        raise TypeError(f"expected an image of type {ImageTypes} (was {type(image)})")

    if HAS_JETSON_UTILS and isinstance(image, cudaImage):
        image = torch.as_tensor(image, dtype=dtype, device=device).permute(2, 0, 1)
        # every floating-point target (including bfloat16 / float64) gets the same
        # 0-255 -> 0-1 normalization, not just float16 / float32
        if dtype is not None and dtype.is_floating_point:
            image = image / 255.0
    elif isinstance(image, (PIL.Image.Image, np.ndarray)):
        image = F.to_tensor(image)

    return image.to(dtype=dtype, device=device)
def torch_image_format(tensor):
    """
    Determine the cudaImage format string (e.g. 'rgb32f', 'rgba8', etc.) from a PyTorch tensor.

    The channel count is read from the shape, assuming NCHW (4+ dims), CHW (3 dims)
    or HW (2 dims) layout. Only float32 and uint8 tensors are supported, because those
    are the datatypes supported by cudaImage.

    Args:
        tensor (torch.Tensor): the image tensor.

    Returns:
        str: one of 'gray32f', 'gray8', 'rgb32f', 'rgb8', 'rgba32f', 'rgba8'.

    Raises:
        ValueError: for an unsupported dtype, fewer than 2 dimensions, or a
            channel count other than 1, 3 or 4.
    """
    if tensor.dtype != torch.float32 and tensor.dtype != torch.uint8:
        raise ValueError(f"PyTorch tensor datatype should be torch.float32 or torch.uint8 (was {tensor.dtype})")

    ndim = tensor.dim()

    if ndim >= 4:    # NCHW layout
        channels = tensor.shape[1]
    elif ndim == 3:  # CHW layout
        channels = tensor.shape[0]
    elif ndim == 2:  # HW layout
        channels = 1
    else:
        # torch.Size has no `.length`; report the number of dimensions instead
        raise ValueError(f"PyTorch tensor should have at least 2 image dimensions (has {ndim})")

    is_float = tensor.dtype == torch.float32

    if channels == 1:
        return 'gray32f' if is_float else 'gray8'
    elif channels == 3:
        return 'rgb32f' if is_float else 'rgb8'
    elif channels == 4:
        return 'rgba32f' if is_float else 'rgba8'

    raise ValueError(f"PyTorch tensor should have 1, 3, or 4 image channels (has {channels})")