"""Plotting routines for making the classic stacked line plots."""
from __future__ import annotations
import warnings
from logging import DEBUG, INFO
from typing import TYPE_CHECKING, Literal, Unpack
import matplotlib as mpl
import matplotlib.colorbar
import matplotlib.colors
import matplotlib.pyplot as plt
import matplotlib.ticker
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.ticker import FixedLocator, MaxNLocator
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from arpes._typing.plotting import MPLPlotKwargsBasic
from arpes.analysis import rebin
from arpes.constants import TWO_DIMENSION
from arpes.debug import setup_logger
from arpes.provenance import save_plot_provenance
from arpes.utilities import normalize_to_spectrum
from .tof import scatter_with_std
from .utils import (
fancy_labels,
label_for_dim,
path_for_plot,
)
if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path
from matplotlib.figure import Figure
from matplotlib.typing import ColorType
from numpy.typing import NDArray
from arpes._typing.base import ReduceMethod
from arpes._typing.plotting import LEGENDLOCATION, ColorbarParam, MPLPlotKwargsBasic
# Public plotting entry points exported by this module.
__all__ = (
    "flat_stack_plot",
    "offset_scatter_plot",
    "stack_dispersion_plot",
    "waterfall_dispersion",
)

# Log levels that may be selected for this module's logger; the last one (INFO) is active.
LOGLEVELS = DEBUG, INFO
LOGLEVEL = LOGLEVELS[-1]
logger = setup_logger(__name__, LOGLEVEL)
@save_plot_provenance
def waterfall_dispersion(  # noqa: PLR0913
    data: xr.DataArray,
    scale_factor: float = 1.0,
    stack_axis: str = "phi",
    ax: Axes | None = None,
    mode: Literal["fill_between", "hide_lines", "hide_line", "line"] = "line",
    cmap: Colormap | str = "black",
    figsize: tuple[float, float] = (7, 5),
    prune: Literal["lower", "upper", "both"] | None = "both",
    *,
    reverse: bool = True,
    **kwargs: Unpack[MPLPlotKwargsBasic],
) -> tuple[Figure | None, Axes, Axes] | tuple[Figure | None, Axes]:
    """Plot a waterfall-style dispersion using 2D `xarray.DataArray`.

    Each line profile along one axis is offset vertically according to the values of the stacking
    axis, allowing visual inspection of variations across slices. A twin y-axis is added on the
    right to indicate the original values of the stacking coordinate.

    Args:
        data (xr.DataArray): A 2D DataArray to plot. Must have exactly two dimensions.
        scale_factor (float, optional): Scaling factor for the vertical offset between stacks.
            Must be non-negative. If 0, every trace shares the same baseline (the "flat stack"
            version) and no right axis is created. Defaults to 1.0.
        stack_axis (str, optional): Name of the dimension along which stacking is performed.
            Defaults to "phi".
        ax (Axes, optional): Matplotlib Axes object to plot into. If None, a new figure and axes
            will be created. Defaults to None.
        mode (Literal["fill_between", "hide_lines", "hide_line", "line"], optional):
            Plotting style for each line:
            - "line": lines only
            - "fill_between": the area between each line and its offset baseline is filled
              with the line color
            - "hide_lines" (alias "hide_line"): the area between each line and its offset
              baseline is filled with white, hiding the traces drawn behind it
            Defaults to "line".
        cmap (Colormap | str, optional): A matplotlib colormap, a registered colormap name, or a
            single color string. Defaults to "black".
        figsize (tuple[float, float], optional): Figure size (ignored if `ax` is provided).
            Defaults to (7, 5).
        prune ({'lower', 'upper', 'both', None}): Passed to `MaxNLocator` for the right-axis
            ticks: remove the 'lower' tick, the 'upper' tick, or the ticks on 'both' sides.
            Defaults to "both".
        reverse (bool, optional): Whether the stacking direction is reversed, i.e. the last
            coordinate value is drawn at the bottom. Defaults to True.
        **kwargs: Additional keyword arguments passed to `ax.plot()` and `ax.fill_between()`.
            `alpha` only affects the fills; the lines are always drawn with alpha 1.

    Returns:
        tuple[Figure | None, Axes, Axes]:
            The figure (if created), the main axes (left y-axis), and the twin axes
            (right y-axis).
        tuple[Figure | None, Axes]:
            Only the figure and the main axes when ``scale_factor == 0``.

    Raises:
        AssertionError: If `data` is not 2D or `scale_factor` is negative.
        KeyError: If `stack_axis` is not a dimension of `data`.

    Notes:
        This waterfall does not have 'nbins' functionality.
        The default style of the label is same as the default output of S.plot.
        Use the following example, when the label text, especially for the right
        axis label, is modified.

        .. code-block:: python

            label = ax_right.yaxis.label
            label.set_text("new label text")
            label.set_fontsize(12)

        or just

        .. code-block:: python

            ax_right.yaxis.label.set_text(stack_axis)
    """
    assert data.ndim == TWO_DIMENSION
    assert scale_factor >= 0, "scale factor should be non-negative."
    if stack_axis not in data.dims:
        msg = f"stack_axis {stack_axis!r} is not a dimension of data {data.dims}."
        raise KeyError(msg)

    fig: Figure | None = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)

    stack_axis_values = data.coords[stack_axis].values
    n_stacks = len(stack_axis_values)
    # The first trace drawn sits at the bottom (offset 0); the last one ends up at the top.
    if reverse:
        bottom_stack, top_stack = stack_axis_values[-1], stack_axis_values[0]
    else:
        bottom_stack, top_stack = stack_axis_values[0], stack_axis_values[-1]
    # +1 if the stacking coordinate grows upward on the plot, -1 if it shrinks. Deriving the
    # sign from the data (not from `reverse` alone) keeps the right axis correct for
    # descending coordinates too.
    direction = -1.0 if top_stack < bottom_stack else 1.0

    plot_axis_name = next(dim for dim in data.dims if dim != stack_axis)
    plot_axis = data.coords[plot_axis_name].values

    def laxis_to_right(laxis_value: float) -> float:
        """Convert a left-axis (offset) value into a stacking-coordinate value."""
        return bottom_stack + direction * laxis_value / scale_factor

    def raxis_to_left(raxis_value: float) -> float:
        """Convert a stacking-coordinate value into its vertical offset on the left axis."""
        return scale_factor * abs(raxis_value - bottom_stack)

    colors: list[tuple[float, float, float, float]] | list[str] = _get_colors(
        cmap=cmap,
        n_stacks=n_stacks,
    )
    fill_alpha = kwargs.get("alpha", 1)
    hide_lines = mode in {"hide_lines", "hide_line"}
    for i, stack_coord in enumerate(data.G.iter_coords(stack_axis, reverse=reverse)):
        offset = raxis_to_left(stack_coord[stack_axis])
        y = data.sel(stack_coord).values + offset
        # Earlier (lower) traces get higher zorders, so each trace and its fill are drawn in
        # front of the traces stacked above it; a fill sits just behind its own line.
        line_zorder = 2 * (n_stacks - i) + 1
        kwargs["alpha"] = 1  # lines are always opaque; the user alpha only applies to fills
        kwargs["color"] = colors[i]
        ax.plot(plot_axis, y, zorder=line_zorder, **kwargs)
        if hide_lines or mode == "fill_between":
            kwargs["alpha"] = fill_alpha
            kwargs["color"] = "white" if hide_lines else colors[i]
            ax.fill_between(plot_axis, y, offset, zorder=line_zorder - 1, **kwargs)

    # set default values.
    if data.name:
        ax.set_ylabel(str(data.name))
    ax.set_xlabel(str(plot_axis_name))

    if scale_factor <= 0:
        # No vertical offsets, so there is no stacking scale to show on a right axis.
        return fig, ax

    ax_right = _set_right_axis(
        ax=ax,
        stack_coords=data.coords[stack_axis],
        axis_converters=(laxis_to_right, raxis_to_left),
        prune=prune,
        reverse=reverse,
    )
    return fig, ax, ax_right
def _get_colors(
cmap: Colormap | str,
n_stacks: int,
) -> list[tuple[float, float, float, float]] | list[str]:
if isinstance(cmap, str):
try:
cmap_ = plt.colormaps[cmap]
return [cmap_(i / (n_stacks - 1)) for i in range(n_stacks)]
except KeyError:
return [cmap for _ in range(n_stacks)]
else: # should be colormaps
return [cmap(i / (n_stacks - 1)) for i in range(n_stacks)]
def _set_right_axis(
    ax: Axes,
    stack_coords: xr.DataArray,
    axis_converters: tuple[Callable[[float], float], Callable[[float], float]],
    prune: Literal["lower", "upper", "both"] | None,
    *,
    reverse: bool,  # noqa: ARG001
) -> Axes:
    """Add and configure a right-side y-axis that reflects the stacking axis values.

    This function creates a twin y-axis (`ax_right`) for a waterfall-style plot, where each
    stacked trace is offset vertically but corresponds to a value in the original `stack_axis`.
    The right-axis limits are the left-axis limits converted with ``axis_converters[0]``, the
    right ticks are chosen by ``MaxNLocator`` over the stacking values, and the right label is
    centered vertically on the span covered by those ticks.

    Args:
        ax (Axes): The main matplotlib Axes on the left side. Call this after plotting so that
            its y-limits are final.
        stack_coords (xr.DataArray): The coordinates of the stacking axis; its name becomes the
            right-axis label.
        axis_converters (tuple[Callable[[float], float], Callable[[float], float]]):
            Functions to convert a left-axis coordinate to right-axis and vice versa.
        prune (Literal['lower', 'upper', 'both'] | None): Passed to ``MaxNLocator``: remove
            the 'lower' tick, the 'upper' tick, or the ticks on 'both' sides.
        reverse (bool): Whether the stacking direction is reversed (i.e., from top to bottom).
            The direction is already encoded in ``axis_converters``, so this flag does not
            change the result; it is kept for interface compatibility.

    Returns:
        Axes: The configured right-side twin Axes (`ax_right`).
    """
    laxis_to_right, raxis_to_left = axis_converters
    ax_right = ax.twinx()
    laxis_bottom, laxis_top = ax.get_ylim()
    # Keep both axes in register: the right limits are the left limits in stack units.
    ax_right.set_ylim(laxis_to_right(laxis_bottom), laxis_to_right(laxis_top))

    # right axis ticks
    stack_axis_values = stack_coords.values
    rticks = MaxNLocator(nbins=5, prune=prune).tick_values(
        vmin=np.min(stack_axis_values),
        vmax=np.max(stack_axis_values),
    )
    ax_right.yaxis.set_major_locator(FixedLocator(rticks))

    # Tune the right axis label position
    ax_right.set_ylabel(str(stack_coords.name))
    if len(rticks) == 0:
        # Nothing to center on (e.g. every candidate tick was pruned): keep the default place.
        return ax_right
    lticks = [raxis_to_left(raxis_value) for raxis_value in rticks]
    ycenter = (min(lticks) + max(lticks)) / 2
    ycoord = (ycenter - laxis_bottom) / (laxis_top - laxis_bottom)
    ax_right.yaxis.set_label_coords(1.07, ycoord)
    return ax_right
[docs]
@save_plot_provenance
def offset_scatter_plot(  # noqa: PLR0913
    data: xr.Dataset,
    name_to_plot: str = "",
    stack_axis: str = "",
    ax: Axes | None = None,
    out: str | Path = "",
    scale_coordinate: float = 0.5,
    ylim: tuple[float, float] | tuple[()] = (),
    fermi_level: float | None = None,
    loc: LEGENDLOCATION = "upper left",
    figsize: tuple[float, float] = (11, 5),
    *,
    color: Colormap | str = "black",
    aux_errorbars: bool = True,
    **kwargs: Unpack[ColorbarParam],
) -> Path | tuple[Figure | None, Axes]:
    """Makes a stack plot (scatters version).

    Each slice along ``stack_axis`` is drawn with ``scatter_with_std``. The ``i``-th slice is
    shifted along the other (horizontal) dimension by ``i * stride * scale_coordinate / 10``.

    Args:
        data (xr.Dataset): The dataset containing the data (and its ``<name>_std`` companion).
        name_to_plot (str): Name of the spectrum (in many cases 'spectrum') to plot. If empty,
            the single data variable whose name does not contain "_std" is used.
        stack_axis (str): The axis along which to stack the plot. Defaults to the first
            dimension of the plotted variable.
        ax (Axes | None): The axes on which to plot; a new figure is created when None.
        out (str | Path): If given, the figure is saved there and its path is returned.
        scale_coordinate (float): Horizontal shift between successive stacks, in units of ten
            strides of the horizontal coordinate. Defaults to 0.5.
        ylim (tuple[float, float]): The y-axis limits; autoscaled when empty (default).
        fermi_level (float | None): If given and the horizontal axis is "eV", a dashed
            vertical line is drawn at this energy. Defaults to None (not drawn).
        loc (LEGENDLOCATION): Location of the colorbar inset, by default "upper left".
        figsize (tuple[float, float]): The figure size, by default (11, 5).
        color (Colormap | str): A single color, or a colormap (object or registered name) used
            to color the stacks. A colorbar is added only in the colormap case.
            Defaults to "black".
        aux_errorbars (bool): Whether to draw an extra copy of each stack with its values
            flattened to the lower y-limit (read before any data are drawn), so that the error
            bars line up at the bottom. Defaults to True.
        **kwargs: Passed to ``matplotlib.colorbar.Colorbar``.

    Returns:
        Path | tuple[Figure | None, Axes]: The path to the saved plot or the figure and axes.
    """
    assert isinstance(data, xr.Dataset)
    if not name_to_plot:
        var_names = [k for k in data.data_vars if "_std" not in str(k)]  # => ["spectrum"]
        assert len(var_names) == 1
        name_to_plot = str(var_names[0])
        assert (name_to_plot + "_std") in data.data_vars, "Has 'mean_and_deviation' been applied?"

    plot_dims = data.data_vars[name_to_plot].dims
    msg = (
        "In order to produce a stack plot, data must be image-like. "
        f"Passed data included dimensions: {plot_dims}"
    )
    assert len(plot_dims) == TWO_DIMENSION, msg

    fig: Figure | None = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    assert isinstance(ax, Axes)

    stack_axis = stack_axis or str(plot_dims[0])
    other_dim = next(str(d) for d in plot_dims if d != stack_axis)
    n_stacks = len(data.coords[stack_axis])

    if "eV" in plot_dims and stack_axis != "eV" and fermi_level is not None:
        # Energy runs along the horizontal axis here, so the Fermi level is a vertical line.
        ax.axvline(fermi_level, linestyle="--", color="red")
        # axvspan always covers the full height without feeding a huge y-range into the
        # autoscaled y-limits.
        ax.axvspan(0, 0.2, color="black", alpha=0.07)

    if ylim:
        ax.set_ylim(bottom=ylim[0], top=ylim[1])
    else:
        ax.set_ylim(auto=True)
    # Lower y-limit before any data are drawn; the auxiliary error bars are placed there.
    baseline = ax.get_ylim()[0]

    # real plotting here
    stride = data.G.stride(generic_dim_names=False)[other_dim]
    for i, coord in enumerate(data.G.iter_coords(stack_axis)):
        stack_color = _color_for_plot(color, i, n_stacks)
        selected = data.sel(coord)
        x_coord = selected.coords[other_dim]
        # Dimension coordinates cannot be modified in place; build a shifted copy instead.
        shifted = selected.assign_coords(
            {
                other_dim: (
                    other_dim,
                    x_coord.values - i * stride * scale_coordinate / 10,
                    x_coord.attrs,
                ),
            },
        )
        scatter_with_std(shifted, name_to_plot, ax=ax, color=stack_color)
        if aux_errorbars:
            spectrum = shifted.data_vars[name_to_plot]
            flattened = spectrum.copy(data=np.full(spectrum.shape, baseline))
            scatter_with_std(
                shifted.assign({name_to_plot: flattened}),
                name_to_plot,
                ax=ax,
                color=stack_color,
            )

    ax.set_xlabel(other_dim)
    ax.set_ylabel(name_to_plot)
    fancy_labels(ax)

    # A colorbar is only meaningful when the stacks are colored by a colormap.
    colorbar_cmap: Colormap | None
    if isinstance(color, Colormap):
        colorbar_cmap = color
    elif isinstance(color, str):
        colorbar_cmap = mpl.colormaps.get(color)
    else:
        colorbar_cmap = None
    if colorbar_cmap is not None:
        kwargs = _set_default_kwargs(kwargs, data=data, stack_axis=stack_axis)
        kwargs.setdefault("cmap", colorbar_cmap)
        inset_ax = inset_axes(ax, width="40%", height="5%", loc=loc)
        inset_ax.set_xlabel(stack_axis, fontsize=16)
        fancy_labels(inset_ax)
        matplotlib.colorbar.Colorbar(
            inset_ax,
            **kwargs,
        )

    if out:
        path = path_for_plot(out)
        plt.savefig(path, dpi=400)
        return path
    return fig, ax
def _set_default_kwargs(
    kwargs: ColorbarParam,
    data: xr.Dataset,
    stack_axis: str,
) -> ColorbarParam:
    """Fill in default ``Colorbar`` arguments for a colorbar describing ``stack_axis``.

    Keys already present in ``kwargs`` are kept. ``kwargs`` is updated in place and returned.

    Defaults:
        orientation: "horizontal"
        label: ``label_for_dim(data, stack_axis)``
        norm: linear normalization between the min and max of the ``stack_axis`` coordinate
        ticks: ``MaxNLocator(2)``
    """
    stack_coord = data.coords[stack_axis]
    default_label = label_for_dim(data, stack_axis)
    default_norm = matplotlib.colors.Normalize(
        vmin=stack_coord.min().item(),
        vmax=stack_coord.max().item(),
    )
    default_ticks = matplotlib.ticker.MaxNLocator(2)

    kwargs.setdefault("orientation", "horizontal")
    kwargs.setdefault("label", default_label)
    kwargs.setdefault("norm", default_norm)
    kwargs.setdefault("ticks", default_ticks)
    return kwargs
[docs]
@save_plot_provenance
def flat_stack_plot(  # noqa: PLR0913 #pragma: no cover
    data: xr.DataArray,
    *,
    stack_axis: str = "",
    ax: Axes | None = None,
    mode: Literal["line", "scatter"] = "line",
    fermi_level: float | None = None,
    figsize: tuple[float, float] = (7, 5),
    title: str = "",
    max_stacks: int = 200,
    out: str | Path = "",
    loc: LEGENDLOCATION = "upper left",
    **kwargs: Unpack[MPLPlotKwargsBasic],
) -> tuple[Figure | None, Axes] | Path:
    """Generates a stack plot with all the lines distinguished by color rather than offset.

    Deprecated: use ``waterfall_dispersion`` with ``scale_factor=0`` instead.

    Args:
        data (xr.DataArray): ARPES data (xr.DataArray is preferred).
        stack_axis (str): Axis for stacking; defaults to the first dimension.
        ax (Axes | None): matplotlib Axes, by default None (a new figure is created).
        mode (Literal["line", "scatter"]): Plot style (line/scatter), by default "line".
        fermi_level (float | None): Energy of the Fermi-level line, drawn only when "eV" is
            the horizontal axis. By default None (not drawn).
        figsize (tuple[float, float]): Figure size, by default (7, 5).
        title (str): Title string, by default "".
        max_stacks (int): Maximum number of stacked spectra; the stack axis is rebinned (mean)
            to respect it. By default 200.
        out (str | Path): Path to the figure, by default "". If given, the figure is saved and
            its path is returned.
        loc (LEGENDLOCATION): Location of the colorbar inset, by default "upper left".
        **kwargs: Additional keywords passed to ax.plot / ax.scatter. ``color`` may be a
            colormap (object or registered name; default "viridis") or a single color. The
            colorbar is drawn only for colormaps.

    Returns:
        Path | tuple[Figure | None, Axes]: The path to the saved plot, or the figure and axes.

    Raises:
        IndexError: If the data is not two dimensional.
        AssertionError: If ``mode`` is neither "line" nor "scatter".
    """
    warnings.warn(
        "This method will be deprecated. Use waterfall_dispersion with scale_factor=0 instead.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    data, stack_axis, horizontal_dim = _rebinning(
        data,
        stack_axis=stack_axis,
        max_stacks=max_stacks,
        method="mean",
    )
    fig: Figure | None = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    assert isinstance(ax, Axes)

    horizontal = data.coords[horizontal_dim]
    n_stacks = len(data.coords[stack_axis])
    if "eV" in data.dims and stack_axis != "eV" and fermi_level is not None:
        ax.axvline(
            fermi_level,
            color="red",
            alpha=0.8,
            linestyle="--",
            linewidth=1,
        )

    color = kwargs.pop("color", "viridis")
    for i, coord in enumerate(data.G.iter_coords(stack_axis)):
        marginal = data.sel(coord, method="nearest")
        kwargs["color"] = _color_for_plot(color, i, n_stacks)
        if mode == "line":
            ax.plot(horizontal, marginal.values, **kwargs)
        else:
            assert mode == "scatter"
            ax.scatter(horizontal, marginal.values, **kwargs)

    # A colorbar only makes sense (and Colorbar only accepts `cmap`) for a real colormap.
    cmap: Colormap | None
    if isinstance(color, Colormap):
        cmap = color
    elif isinstance(color, str):
        cmap = mpl.colormaps.get(color)
    else:
        cmap = None
    if cmap is not None:
        ax_inset = inset_axes(ax, width="40%", height="5%", loc=loc)
        matplotlib.colorbar.Colorbar(
            ax_inset,
            orientation="horizontal",
            label=label_for_dim(data, stack_axis),
            norm=matplotlib.colors.Normalize(
                vmin=data.coords[stack_axis].min().item(),
                vmax=data.coords[stack_axis].max().item(),
            ),
            ticks=matplotlib.ticker.MaxNLocator(2),
            cmap=cmap,
        )

    ax.set_xlabel(label_for_dim(data, horizontal_dim))
    ax.set_ylabel("Spectrum Intensity (arb).")
    ax.set_title(title, fontsize=14)
    ax.set_xlim(left=horizontal.min().item(), right=horizontal.max().item())

    if out:
        path = path_for_plot(out)
        plt.savefig(path, dpi=400)
        return path
    return fig, ax
[docs]
@save_plot_provenance
def stack_dispersion_plot(  # noqa: PLR0913 # pragma: no cover
    data: xr.DataArray,
    *,
    stack_axis: str = "",
    ax: Axes | None = None,
    out: str | Path = "",
    max_stacks: int = 100,
    scale_factor: float = 0,
    mode: Literal["fill_between", "hide_line", "line", "scatter"] = "line",
    offset_correction: Literal["zero", "constant", "constant_right"] | None = "zero",
    shift: float = 0,
    negate: bool = False,
    figsize: tuple[float, float] = (7.0, 7.0),
    title: str = "",
    **kwargs: Unpack[MPLPlotKwargsBasic],
) -> Path | tuple[Figure | None, Axes]:
    """Generates a stack plot with all the lines distinguished by offset (and color).

    Deprecated: use ``waterfall_dispersion`` instead.

    Args:
        data (xr.DataArray): ARPES data (two dimensional).
        stack_axis (str): Stack axis, e.g. "phi", "eV", ... Defaults to the first dimension.
        ax (Axes | None): matplotlib Axes object; a new figure is created when None.
        out (str | Path): Path for the output figure. If given, the figure is saved and the
            path is returned.
        max_stacks (int): Maximum number of stacked spectra; the stack axis is rebinned
            (summed) to respect it.
        scale_factor (float): Vertical scale applied to each normalized trace. If 0 (default),
            it is estimated from the data.
        mode (Literal["line", "fill_between", "hide_line", "scatter"]): Draw mode. Any other
            value is drawn like "scatter".
        offset_correction (Literal["zero", "constant", "constant_right"] | None): Baseline
            removed from each trace before scaling: "zero" keeps the raw values, "constant"
            subtracts the first value, "constant_right" the last value, and None a straight
            line through the first and last values. Defaults to "zero".
        shift (float): Horizontal shift applied per stack.
        negate (bool): Whether to flip the sign of the intensities before plotting.
        figsize (tuple[float, float]): Figure size, default is (7.0, 7.0).
        title (str, optional): Title of the figure; derived from the data when empty.
        **kwargs: Passed to ax.plot / fill_between / scatter. Can set linewidth etc., here.
            (See _typing/MPLPlotKwargsBasic)

    Returns:
        Path | tuple[Figure | None, Axes]: The path to the saved figure or the figure and axes.
    """
    warnings.warn(
        "This method will be deprecated. "
        "Use waterfall_dispersion instead; its simpler design makes it much easier to use.",
        category=DeprecationWarning,
        stacklevel=2,
    )
    data_arr, stack_axis, other_axis = _rebinning(
        data,
        stack_axis=stack_axis,
        max_stacks=max_stacks,
    )
    fig: Figure | None = None
    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    assert isinstance(ax, Axes)
    if not title:
        title = (
            f"ID: {data_arr.S.parent_id} Stack"
            if data_arr.S.parent_id
            else f"{data_arr.S.label.replace('_', ' ')} Stack"
        )
    max_intensity_over_stacks = np.nanmax(data_arr.values)
    cvalues: NDArray[np.floating] = data_arr.coords[other_axis].values
    if not scale_factor:
        scale_factor = _scale_factor(
            data_arr,
            stack_axis=stack_axis,
            offset_correction=offset_correction,
            negate=negate,
        )

    stack_coord = data_arr.coords[stack_axis]
    n_stacks = len(stack_coord)
    lim = [np.inf, -np.inf]
    color = kwargs.pop("color", "black")
    for i, coord_dict in enumerate(data_arr.G.iter_coords(stack_axis, reverse=True)):
        coord_value = coord_dict[stack_axis]
        ys = _y_shifted(
            offset_correction=offset_correction,
            coord_value=coord_value,
            marginal=data_arr.sel(coord_dict),
            scale_parameters=(scale_factor, max_intensity_over_stacks, negate),
        )
        xs = cvalues - i * shift
        lim = [min(lim[0], float(np.min(xs))), max(lim[1], float(np.max(xs)))]
        # Per-trace copies so that per-mode overrides never leak into later traces.
        trace_kwargs = {**kwargs, "color": _color_for_plot(color, i, n_stacks)}
        if mode == "line":
            ax.plot(xs, ys, **trace_kwargs)
        elif mode == "hide_line":
            ax.plot(xs, ys, **trace_kwargs, zorder=i * 2 + 1)
            # Opaque white fill under the trace hides the traces drawn behind it.
            hide_kwargs = {**kwargs, "color": "white", "alpha": 1}
            ax.fill_between(xs, ys, coord_value, zorder=i * 2, **hide_kwargs)
        elif mode == "fill_between":
            fill_kwargs = {**trace_kwargs, "alpha": 1}
            ax.fill_between(xs, ys, coord_value, zorder=i * 2, **fill_kwargs)
        else:
            ax.scatter(xs, ys, **trace_kwargs)

    stack_min = stack_coord.min().item()
    stack_max = stack_coord.max().item()
    y_tick_region = [
        tick
        for tick in matplotlib.ticker.MaxNLocator(5).tick_values(stack_min, stack_max)
        if stack_min < tick < stack_max
    ]
    ax.set_yticks(np.array(y_tick_region))
    ax.set_ylabel(label_for_dim(data_arr, stack_axis))
    ylims = ax.get_ylim()
    # Center the y label on the middle interior tick; fall back to the middle of the stack
    # range when there are no interior ticks (indexing a fixed position could overflow).
    median_along_stack_axis = (
        y_tick_region[len(y_tick_region) // 2] if y_tick_region else (stack_min + stack_max) / 2
    )
    ax.yaxis.set_label_coords(
        -0.09,
        (median_along_stack_axis - ylims[0]) / (ylims[1] - ylims[0]),
    )
    ax.set_xlabel(label_for_dim(data_arr, other_axis))

    # Pad the horizontal range by 10% in total (5% on each side).
    axis_min, axis_max = min(lim), max(lim)
    middle = (axis_min + axis_max) / 2
    half_width = (axis_max - axis_min) / 2 * 11 / 10
    ax.set_xlim(left=middle - half_width, right=middle + half_width)
    ax.set_title(title)

    if out:
        path = path_for_plot(out)
        plt.savefig(path, dpi=400)
        return path
    return fig, ax
def _y_shifted(
offset_correction: Literal["zero", "constant", "constant_right"] | None,
marginal: xr.DataArray,
coord_value: NDArray[np.floating],
scale_parameters: tuple[float, float, bool],
) -> NDArray[np.floating]:
scale_factor = scale_parameters[0]
max_intensity_over_stacks = scale_parameters[1]
negate = scale_parameters[2]
marginal_values = -marginal.values if negate else marginal.values
marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]
if offset_correction == "zero":
true_ys = marginal_values / max_intensity_over_stacks
elif offset_correction == "constant":
true_ys = (marginal_values - marginal_offset) / max_intensity_over_stacks
elif offset_correction == "constant_right":
true_ys = (marginal_values - right_marginal_offset) / max_intensity_over_stacks
else: # is this procedure phyically correct?
true_ys = (
marginal_values
- np.linspace(marginal_offset, right_marginal_offset, len(marginal_values))
) / max_intensity_over_stacks
return scale_factor * true_ys + coord_value
def _scale_factor(
data_arr: xr.DataArray,
stack_axis: str,
*,
offset_correction: Literal["zero", "constant", "constant_right"] | None = "zero",
negate: bool = False,
) -> float:
"""Determine the scale factor."""
maximum_deviation = -np.inf
for coords in data_arr.G.iter_coords(stack_axis):
marginal = data_arr.sel(coords, method="nearest")
marginal_values = -marginal.values if negate else marginal.values
marginal_offset, right_marginal_offset = marginal_values[0], marginal_values[-1]
if offset_correction == "zero":
true_ys = marginal_values
elif offset_correction is not None and offset_correction.startswith("constant"):
true_ys = marginal_values - marginal_offset
else:
true_ys = marginal_values - np.linspace(
marginal_offset,
right_marginal_offset,
len(marginal_values),
)
maximum_deviation = np.max([maximum_deviation, *np.abs(true_ys)])
return float(
10.0
* (data_arr.coords[stack_axis].max() - data_arr.coords[stack_axis].min()).item()
/ maximum_deviation,
)
def _rebinning(
    data: xr.DataArray,
    stack_axis: str,
    max_stacks: int,
    method: ReduceMethod = "sum",
) -> tuple[xr.DataArray, str, str]:
    """Preparation for stack plot.

    1. Normalize ``data`` to a DataArray and check that it is two dimensional.
    2. Determine the stack axis (the first dimension when ``stack_axis`` is empty) and the
       name of the other (horizontal) axis.
    3. Rebin along the stack axis with a bin width of ``ceil(n_stack / max_stacks)``.

    Args:
        data: Data to prepare (non-DataArray input goes through ``normalize_to_spectrum``).
        stack_axis: Name of the stacking dimension, or "" for the first dimension.
        max_stacks: Maximum number of stacks wanted; must be at least 1.
        method: Reduction used by ``rebin``.

    Returns:
        tuple[xr.DataArray, str, str]: Rebinned data, stack axis name, other axis name.

    Raises:
        IndexError: If the data is not two dimensional.
        ValueError: If ``stack_axis`` is not a dimension or ``max_stacks`` is below 1.
    """
    data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
    assert isinstance(data_arr, xr.DataArray)
    # Check the normalized array, which is what is actually stacked.
    if len(data_arr.dims) != TWO_DIMENSION:
        msg = "In order to produce a stack plot, data must be image-like. "
        msg += f"Passed data included dimensions: {data_arr.dims}"
        raise IndexError(
            msg,
        )
    if max_stacks < 1:
        msg = f"max_stacks must be a positive integer, got {max_stacks}."
        raise ValueError(msg)

    stack_axis = stack_axis or str(data_arr.dims[0])
    other_axes = [str(dim) for dim in data_arr.dims if dim != stack_axis]
    if len(other_axes) != 1:
        msg = f"stack_axis {stack_axis!r} is not a dimension of the data {data_arr.dims}."
        raise ValueError(msg)
    horizontal_axis = other_axes[0]

    n_stack = data_arr.sizes[stack_axis]
    bin_width = int(np.ceil(n_stack / max_stacks))
    return (
        rebin(
            data_arr,
            bin_width={stack_axis: bin_width},
            method=method,
        ),
        stack_axis,
        horizontal_axis,
    )
def _color_for_plot(
    color: Colormap | ColorType,
    i: int,
    num_plot: int,
) -> ColorType:
    """Return the color of the ``i``-th of ``num_plot`` traces.

    A colormap, or the name of a registered colormap, is sampled at ``|i / num_plot|``.
    Anything else (a color name, hex string, or RGB(A) tuple) is returned unchanged.
    """
    if isinstance(color, Colormap):
        colormap = color
    elif isinstance(color, str):
        colormap = mpl.colormaps.get(color)
        if colormap is None:  # not a colormap name, so it must be a plain color name
            return color
    else:  # e.g. a tuple representing the color
        return color
    return colormap(np.abs(i / num_plot))