Source code for arpes.plotting.parameter

"""Utilities for plotting parameter data out of bulk fits."""

from __future__ import annotations

from typing import TYPE_CHECKING, Unpack

import matplotlib.pyplot as plt
from matplotlib.axes import Axes

from arpes.provenance import save_plot_provenance

from .utils import latex_escape

if TYPE_CHECKING:
    import numpy as np
    import xarray as xr
    from numpy.typing import NDArray

    from arpes._typing.plotting import MPLPlotKwargs

__all__ = ("plot_parameter",)


@save_plot_provenance
def plot_parameter(  # noqa: PLR0913
    fit_data: xr.DataArray,
    param_name: str,
    ax: Axes | None = None,
    shift: float = 0,
    x_shift: float = 0,
    *,
    two_sigma: bool = False,
    figsize: tuple[float, float] = (7, 5),
    **kwargs: Unpack[MPLPlotKwargs],
) -> Axes:
    """Plot one fitted parameter, with error bars, from a `broadcast_fit` result.

    The parameter is extracted with ``fit_data.F.param_as_dataset(param_name)``.
    Its ``value`` variable is drawn with square markers against the first
    dimension of ``value``, and its ``error`` variable gives the vertical error
    bars.

    Args:
        fit_data (xr.DataArray): The fitting result, typically from
            `broadcast_fit.results`.
        param_name (str): The name of the parameter to plot. It is also used
            (LaTeX-escaped) as the y-axis label.
        ax (Axes, optional): The axes on which to plot. If not provided, a new
            figure and axes are created using ``figsize``.
        shift (float, optional): Offset added to the plotted values. Default is 0.
        x_shift (float, optional): Offset added to the x-coordinates. Default is 0.
        two_sigma (bool, optional): If True, an extra layer of thin error bars
            spanning twice the error is drawn before the regular error bars.
            Default is False.
        figsize (tuple[float, float], optional): Size of the figure created when
            ``ax`` is None. Default is (7, 5).
        kwargs: Additional keyword arguments forwarded to ``Axes.errorbar``
            (e.g., `color`, `markersize`, etc.). Defaults applied when absent:
            ``fillstyle="none"``, ``markersize=8``, ``color="#1f77b4"`` and
            ``markeredgewidth=2``. A user-supplied ``fmt`` only affects the
            two-sigma layer; the final layer always uses ``fmt="s"``. When
            ``two_sigma`` is True, ``markeredgewidth`` and ``linewidth`` of the
            final layer are forced to 2 and 0, and a ``label`` is attached to
            the final layer only.

    Returns:
        Axes: The Axes object with the plot.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
    # Type-narrowing guard: from here on ``ax`` is known to be an Axes.
    assert isinstance(ax, Axes)

    ds = fit_data.F.param_as_dataset(param_name)
    x_name = ds.value.dims[0]
    x: NDArray[np.floating] = ds.coords[x_name].values

    # Both error-bar layers share the same shifted data, so compute it once.
    xs = x + x_shift
    ys = ds.value.values + shift
    errors = ds.error.values

    kwargs.setdefault("fillstyle", "none")
    kwargs.setdefault("markersize", 8)
    kwargs.setdefault("color", "#1f77b4")  # matplotlib.colors.TABLEAU_COLORS["tab:blue"]
    kwargs.setdefault("fmt", "")

    e_width = None
    if two_sigma:
        # Keep ``label`` off this layer: passing it to both errorbar calls would
        # create two legend entries for a single data series.
        two_sigma_kwargs = {key: val for key, val in kwargs.items() if key != "label"}
        # NOTE(review): with the default fmt="" this layer uses matplotlib's
        # default data-line format -- confirm that is the intended look.
        _, _, lines = ax.errorbar(
            xs,
            ys,
            yerr=2 * errors,
            elinewidth=1,
            **two_sigma_kwargs,
        )
        # The final layer gets thicker error bars in the colour that was
        # actually used for the two-sigma bars, and no connecting line.
        e_width = 2
        kwargs["markeredgewidth"] = 2
        kwargs["color"] = lines[0].get_color()[0]
        kwargs["linewidth"] = 0

    # The final layer is always drawn with square markers.
    kwargs["fmt"] = "s"
    kwargs.setdefault("markeredgewidth", 2)
    ax.errorbar(
        xs,
        ys,
        yerr=errors,
        elinewidth=e_width,
        **kwargs,
    )

    ax.set_xlabel(latex_escape(x_name))
    ax.set_ylabel(latex_escape(param_name))
    return ax