aboutsummaryrefslogblamecommitdiffstats
path: root/contrib/python/matplotlib/py3/mpl_toolkits/axes_grid1/axes_rgb.py
blob: 52fd707e87043eb17929a84c53da2741722ce002 (plain) (tree)



























































































































































                                                                                
from types import MethodType

import numpy as np

from .axes_divider import make_axes_locatable, Size
from .mpl_axes import Axes, SimpleAxisArtist


def make_rgb_axes(ax, pad=0.01, axes_class=None, **kwargs):
    """
    Create three small Axes stacked in a column to the right of *ax*.

    *ax* is given a new axes locator so that it occupies the left column of
    a divider built from its own extent. The three new Axes are placed in the
    right column and share their x and y axes with *ax*. Their tick labels
    are hidden.

    Parameters
    ----------
    ax : `~matplotlib.axes.Axes`
        Axes instance to create the RGB Axes in.
    pad : float, optional
        Fraction of the Axes height to pad.
    axes_class : `matplotlib.axes.Axes` or None, optional
        Axes class to use for the R, G, and B Axes. If None, use
        the same class as *ax*.
    **kwargs
        Forwarded to *axes_class* init for the R, G, and B Axes.

    Returns
    -------
    list
        The three new Axes, ordered top row first (R, G, B).
    """
    if axes_class is None:
        axes_class = type(ax)

    divider = make_axes_locatable(ax)

    # Each side panel gets one third of what is left once the two gaps
    # between the three stacked panels are taken out.
    panel_frac = (1 - 2 * pad) / 3
    gap = pad * Size.AxesY(ax)
    panel_w = panel_frac * Size.AxesX(ax)
    panel_h = panel_frac * Size.AxesY(ax)

    # Columns: main Axes | gap | side panels.
    divider.set_horizontal([Size.AxesX(ax), gap, panel_w])
    # Rows: panel | gap | panel | gap | panel.
    divider.set_vertical([panel_h, gap, panel_h, gap, panel_h])

    # The main Axes sits in column 0 and spans all rows (ny1=-1).
    ax.set_axes_locator(divider.new_locator(0, 0, ny1=-1))

    side_axes = []
    for row in (4, 2, 0):
        child = axes_class(ax.get_figure(), ax.get_position(original=True),
                           sharex=ax, sharey=ax, **kwargs)
        child.set_axes_locator(divider.new_locator(nx=2, ny=row))

        tick_labels = [*child.yaxis.get_ticklabels(),
                       *child.xaxis.get_ticklabels()]
        for label in tick_labels:
            label.set_visible(False)

        # Axes classes that expose ``axis`` as a dict of axis artists also
        # need those artists' tick labels hidden. A plain Axes has an
        # ``axis`` method instead, and the resulting AttributeError is
        # ignored on purpose.
        try:
            for artist in child.axis.values():
                artist.major_ticklabels.set_visible(False)
        except AttributeError:
            pass

        side_axes.append(child)

    figure = ax.get_figure()
    for child in side_axes:
        figure.add_axes(child)

    return side_axes


class RGBAxes:
    """
    4-panel `~.Axes.imshow` (RGB, R, G, B).

    Layout::

        ┌───────────────┬─────┐
        │               │  R  │
        │               ├─────┤
        │      RGB      │  G  │
        │               ├─────┤
        │               │  B  │
        └───────────────┴─────┘

    Subclasses can override the ``_defaultAxesClass`` attribute.
    By default RGBAxes uses `.mpl_axes.Axes`.

    Attributes
    ----------
    RGB : ``_defaultAxesClass``
        The Axes object for the three-channel `~.Axes.imshow`.
    R : ``_defaultAxesClass``
        The Axes object for the red channel `~.Axes.imshow`.
    G : ``_defaultAxesClass``
        The Axes object for the green channel `~.Axes.imshow`.
    B : ``_defaultAxesClass``
        The Axes object for the blue channel `~.Axes.imshow`.
    """

    _defaultAxesClass = Axes

    def __init__(self, *args, pad=0, **kwargs):
        """
        Create the RGB Axes and the three single-channel Axes next to it.

        Parameters
        ----------
        pad : float, default: 0
            Fraction of the Axes height to put as padding.
        axes_class : `~matplotlib.axes.Axes`
            Axes class to use. If not provided, ``_defaultAxesClass`` is used.
            Must be passed as a keyword argument.
        *args
            Forwarded to *axes_class* init for the RGB Axes
        **kwargs
            Forwarded to *axes_class* init for the RGB, R, G, and B Axes
        """
        axes_class = kwargs.pop("axes_class", self._defaultAxesClass)
        self.RGB = ax = axes_class(*args, **kwargs)
        ax.get_figure().add_axes(ax)
        self.R, self.G, self.B = make_rgb_axes(
            ax, pad=pad, axes_class=axes_class, **kwargs)
        # Set the line color and ticks for the axes.
        for ax1 in [self.RGB, self.R, self.G, self.B]:
            if isinstance(ax1.axis, MethodType):
                # A plain Axes exposes ``axis`` as a bound method rather than
                # a dict of axis artists, so build an equivalent dict here.
                # AxisDict keeps a reference to the Axes it describes (used by
                # its ``__call__``), so it must wrap *ax1*, not this container.
                ad = Axes.AxisDict(ax1)
                ad.update(
                    bottom=SimpleAxisArtist(ax1.xaxis, 1, ax1.spines["bottom"]),
                    top=SimpleAxisArtist(ax1.xaxis, 2, ax1.spines["top"]),
                    left=SimpleAxisArtist(ax1.yaxis, 1, ax1.spines["left"]),
                    right=SimpleAxisArtist(ax1.yaxis, 2, ax1.spines["right"]))
            else:
                ad = ax1.axis
            ad[:].line.set_color("w")
            ad[:].major_ticks.set_markeredgecolor("w")

    def imshow_rgb(self, r, g, b, **kwargs):
        """
        Create the four images {rgb, r, g, b}.

        The RGB panel shows the three channels stacked along a new last axis.
        Each single-channel panel shows an image that is zero everywhere
        except in its own channel.

        Parameters
        ----------
        r, g, b : array-like
            The red, green, and blue arrays. They must have identical shapes
            and are expected to be 2D (M, N).
        **kwargs
            Forwarded to `~.Axes.imshow` calls for the four images.

        Returns
        -------
        rgb : `~matplotlib.image.AxesImage`
        r : `~matplotlib.image.AxesImage`
        g : `~matplotlib.image.AxesImage`
        b : `~matplotlib.image.AxesImage`

        Raises
        ------
        ValueError
            If the shapes of *r*, *g* and *b* differ.
        """
        # Accept any array-like (e.g. nested lists) as documented. asanyarray
        # returns ndarrays and ndarray subclasses unchanged.
        r, g, b = (np.asanyarray(channel) for channel in (r, g, b))
        if not (r.shape == g.shape == b.shape):
            raise ValueError(
                f'Input shapes ({r.shape}, {g.shape}, {b.shape}) do not match')
        RGB = np.dstack([r, g, b])
        R = np.zeros_like(RGB)
        R[:, :, 0] = r
        G = np.zeros_like(RGB)
        G[:, :, 1] = g
        B = np.zeros_like(RGB)
        B[:, :, 2] = b
        im_rgb = self.RGB.imshow(RGB, **kwargs)
        im_r = self.R.imshow(R, **kwargs)
        im_g = self.G.imshow(G, **kwargs)
        im_b = self.B.imshow(B, **kwargs)
        return im_rgb, im_r, im_g, im_b