# -*- coding: utf-8 -*-
"""Provides utility functions for visualization."""
from warnings import warn
from math import ceil
import matplotlib.pyplot as plt
# import mpl_toolkits.axes_grid.axes_size as Size
# from mpl_toolkits.axes_grid import Divider
import numpy as np
[docs]
def plot_image(x, fig=None, ax=None, **kwargs):
    """Show a single image on a matplotlib axes via ``ax.imshow``.

    The image data is converted with ``np.asarray`` and transposed before
    it is handed to ``imshow``.

    Parameters
    ----------
    x : array-like or PIL image
        The image data. For further information see `imshow documentation
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_.
    fig : :class:`matplotlib.figure.Figure`, optional
        Figure to draw into. If omitted while `ax` is given, the figure is
        looked up from `ax`. If both `fig` and `ax` are omitted, a new
        figure is created.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes to draw into. If `None`, a new axes is added to `fig` with
        ``fig.add_subplot(111)``.
    kwargs : dict, optional
        Keyword arguments forwarded to ``ax.imshow``. ``cmap`` defaults to
        ``'gray'``. The keys ``xticks`` and ``yticks`` are not forwarded;
        if present and not `None`, they are applied with ``ax.set_xticks``
        and ``ax.set_yticks`` respectively.

    Returns
    -------
    im : :class:`matplotlib.image.AxesImage`
        The image that was plotted.
    ax : :class:`matplotlib.axes.Axes`
        The axes the image was plotted in.
    """
    if ax is None:
        if fig is None:
            fig = plt.figure()
        ax = fig.add_subplot(111)
    elif fig is None:
        # Resolve the owning figure as documented; it is not needed below.
        fig = ax.get_figure()

    # Caller-supplied options win over the gray colormap default.
    imshow_options = {'cmap': 'gray', **kwargs}

    # Tick positions are axes settings, not imshow arguments.
    x_positions = imshow_options.pop('xticks', None)
    y_positions = imshow_options.pop('yticks', None)
    if x_positions is not None:
        ax.set_xticks(x_positions)
    if y_positions is not None:
        ax.set_yticks(y_positions)

    transposed = np.asarray(x).T
    image = ax.imshow(transposed, **imshow_options)
    return image, ax
[docs]
def plot_images(x_list, nrows=1, ncols=-1, fig=None, vrange='equal',
                cbar='auto', rect=None, fig_size=None, **kwargs):
    """Plot multiple images using matplotlib's :meth:`imshow` method in
    subplots.

    All inputs are validated before a figure is created, so invalid
    arguments do not leave an empty figure behind.

    Parameters
    ----------
    x_list : sequence of (array-like or PIL image)
        List of the image data. Must contain at least one image. For
        further information see `imshow documentation
        <https://matplotlib.org/api/_as_gen/matplotlib.pyplot.imshow.html>`_.
    nrows : int, optional
        The number of subplot rows (the default is 1). If -1 or `None`, it
        is computed by ``ceil(len(x_list)/ncols)``, or set to 1 if `ncols`
        is not given.
    ncols : int, optional
        The number of subplot columns. It is always recomputed as
        ``ceil(len(x_list)/nrows)``; a passed value is only used to derive
        `nrows` when `nrows` is -1 or `None`.
    fig : :class:`matplotlib.figure.Figure`, optional
        The figure to plot the images in. If `None`, a new figure is
        created.
    vrange : {``'equal'``, ``'individual'``} or (float, float) or \
list of (float, float), optional
        Value ranges for the colors of the images.
        If a string is passed, the range is auto-computed:

            ``'equal'``
                The same colors are used for all images.
            ``'individual'``
                The colors differ between the images.

        If a tuple of two scalars is passed, it is used for all images.
        Otherwise `vrange` is taken as a sequence holding one
        ``(vmin, vmax)`` pair per image.
    cbar : {``'one'``, ``'many'``, ``'auto'``, ``'none'``}, optional
        Colorbar option.
        If ``cbar=='one'``, one colorbar is shown. Only possible if the
        value ranges used for the colors (cf. `vrange`) are the same for
        all images; otherwise a warning is issued and ``'many'`` is used.
        If ``cbar=='many'``, a colorbar is shown for every image.
        If ``cbar=='auto'``, either ``'one'`` or ``'many'`` is chosen,
        depending on whether `vrange` is equal for all images.
        If ``cbar=='none'``, no colorbars are shown.
    rect : optional
        Currently unused; accepted for backward compatibility.
    fig_size : optional
        If not `None`, passed to ``fig.set_size_inches``.
    kwargs : dict, optional
        Keyword arguments passed to `plot_image`, which in turn passes them
        to ``imshow``.

    Returns
    -------
    im : ndarray of :class:`matplotlib.image.AxesImage`
        The images that were plotted, with the same shape as `ax`. Entries
        belonging to axes without an image are `None`.
    ax : ndarray of :class:`matplotlib.axes.Axes`
        The axes the images were plotted in (at least 1-dimensional).

    Raises
    ------
    TypeError
        If `x_list` is not iterable.
    ValueError
        If `x_list` is empty, if `vrange` is an unknown string, or if a
        per-image `vrange` sequence has fewer entries than `x_list`.
    """
    try:
        x_list = list(x_list)
    except TypeError as err:
        raise TypeError('x_list must be iterable. Pass a sequence or use '
                        '`plot_image` to plot single images.') from err
    x_list = [np.asarray(x) for x in x_list]
    if not x_list:
        raise ValueError('x_list must contain at least one image.')
    n_images = len(x_list)

    # Grid layout: the column count is always derived from the row count.
    if nrows is None or nrows == -1:
        if ncols is None or ncols == -1:
            nrows = 1
        else:
            nrows = ceil(n_images / ncols)
    ncols = ceil(n_images / nrows)

    # Resolve one (vmin, vmax) pair per image.
    if isinstance(vrange, str):
        if vrange == 'equal':
            shared = (min(np.min(x) for x in x_list),
                      max(np.max(x) for x in x_list))
            vrange_ = [shared] * n_images
            vrange_equal = True
        elif vrange == 'individual':
            vrange_ = [(np.min(x), np.max(x)) for x in x_list]
            vrange_equal = False
        else:
            raise ValueError("`vrange` must be 'equal' or 'individual'")
    elif (isinstance(vrange, tuple) and len(vrange) == 2
          and all(np.ndim(bound) == 0 for bound in vrange)):
        # A pair of scalars is a single range shared by all images; a pair
        # of pairs falls through to the per-image branch below.
        vrange_ = [vrange] * n_images
        vrange_equal = True
    else:
        vrange_ = list(vrange)
        if len(vrange_) < n_images:
            raise ValueError(
                'vrange must provide one (vmin, vmax) pair per image, got '
                '{} pairs for {} images'.format(len(vrange_), n_images))
        vrange_equal = False

    if not vrange_equal:
        if cbar == 'one':
            warn("cannot use cbar='one' when vrange is not equal for all "
                 "images, falling back to cbar='many'")
        if cbar != 'none':
            cbar = 'many'
    elif cbar == 'auto':
        cbar = 'one'

    if fig is None:
        fig = plt.figure()
    if fig_size is not None:
        fig.set_size_inches(fig_size)

    # ``fig.subplots`` returns a bare Axes for a 1x1 grid; wrap it so the
    # result can always be treated as an array. Arrays pass through as-is.
    ax = np.atleast_1d(fig.subplots(nrows, ncols))
    im = np.empty(ax.shape, dtype=object)
    for i, (x, ax_, v) in enumerate(zip(x_list, ax.flat, vrange_)):
        im_, _ = plot_image(x, ax=ax_, vmin=v[0], vmax=v[1], **kwargs)
        im.flat[i] = im_
        if cbar == 'many':
            fig.colorbar(im_, ax=ax_)
    if cbar == 'one':
        # ``im`` may be 2-D, where ``im[0]`` would be a whole row of
        # images; ``flat[0]`` is always the first plotted image.
        fig.colorbar(im.flat[0], ax=ax)
    return im, ax