"""Annotations onto plots for experimental conditions or locations."""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal, Unpack
import matplotlib as mpl
import numpy as np
import xarray as xr
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d import Axes3D
from arpes.constants import TWO_DIMENSION
from .utils import name_for_dim, unit_for_dim
if TYPE_CHECKING:
from collections.abc import Sequence
from numpy.typing import NDArray
from arpes._typing.attrs_property import ExperimentInfo
from arpes._typing.base import XrTypes
from arpes._typing.plotting import MPLTextParam
# Public API of this module: only these names are exported by ``from ... import *``.
__all__ = (
    "annotate_cuts",
    "annotate_experimental_conditions",
    "annotate_point",
)
# Multipliers that turn a named font size into a factor of the base ``font.size``.
# Used by ``annotate_experimental_conditions`` to derive the line spacing when ``fontsize``
# is given by name; see matplotlib.font_manager for the table this is modelled on.
font_scalings: dict[str, float] = {
    "xx-small": 0.579,
    "x-small": 0.694,
    "small": 0.833,
    "medium": 1.0,
    "large": 1.200,
    "x-large": 1.440,
    "xx-large": 1.728,
    "larger": 1.2,
    "smaller": 0.833,
}
# TODO @<R.Arafune>: Not useful in its current form; revision required.
# * To avoid placing the text in data coordinates, pass transform=ax.transAxes to ax.text.
def annotate_experimental_conditions(
    ax: Axes,
    data: XrTypes,
    desc: list[str | float] | float | str,
    *,
    show: bool = False,
    orientation: Literal["top", "bottom"] = "top",
    **kwargs: Unpack[MPLTextParam],
) -> None:
    """Render information about the experimental conditions onto a set of axes.

    The axes limits are reset to 0-100 in both directions, the grid is switched off and,
    unless ``show`` is set, the axes are hidden and their background made transparent.

    Args:
        ax: The axes to write the annotations onto.
        data: The data being described; ``data.S.experimental_conditions`` supplies the values.
        desc: What to render.  Each item is either one of 'temp', 'photon', 'hv',
            'photon polarization', 'polarization' (case-insensitive, underscores are read as
            spaces), or a number.  A number draws nothing and moves the text cursor by that
            amount plus one line step.  A single item may be passed without a list.
        show: If True, keep the axes visible.
        orientation: 'top' starts at y=100 and stacks lines downwards; 'bottom' starts at y=0
            and stacks lines upwards.
        kwargs: Passed through to ``ax.text``.  ``fontsize`` additionally sets the line step;
            when it is absent the step corresponds to a size of 16.

    Raises:
        RuntimeError: If ``fontsize`` is neither a number nor a named font size known to
            ``font_scalings``.
        KeyError: If a string item of ``desc`` is not one of the supported names.
    """
    if isinstance(desc, str | int | float):
        desc = [desc]

    ax.grid(visible=False)
    ax.set_ylim(bottom=0, top=100)
    ax.set_xlim(left=0, right=100)
    if not show:
        ax.set_axis_off()
        ax.patch.set_alpha(0)

    if orientation == "bottom":
        direction, current = 1.0, 0.0
    else:
        direction, current = -1.0, 100.0

    fontsize_keyword = kwargs.get("fontsize", 16)
    # Integers (including the default of 16) are valid numeric sizes, not only floats.
    if isinstance(fontsize_keyword, int | float):
        fontsize = float(fontsize_keyword)
    elif fontsize_keyword in font_scalings:
        # Use the live rcParams so that runtime changes to "font.size" are honoured;
        # mpl.rc_params() would re-read the default rc file instead.
        fontsize = mpl.rcParams["font.size"] * font_scalings[fontsize_keyword]
    else:
        err_msg = "Incorrect font size setting"
        raise RuntimeError(err_msg)
    delta = fontsize * direction

    conditions: ExperimentInfo = data.S.experimental_conditions
    renderers = {
        "temp": lambda c: f"\\textbf{{T = {c['temp']:.3g} K}}",
        "photon": _render_photon,
        "hv": _render_photon,
        "photon polarization": lambda c: _render_photon(c) + ", " + _render_polarization(c),
        "polarization": _render_polarization,
    }

    for item in desc:
        # Any real number is a spacer; an int must not fall through to the string handling.
        if isinstance(item, int | float):
            current += item + delta
            continue
        item_replaced = item.replace("_", " ").lower()
        ax.text(0, current, renderers[item_replaced](conditions), **kwargs)
        current += delta
def _render_polarization(conditions: dict[str, str]) -> str:
    """Build the bold LaTeX label describing the light polarization.

    Args:
        conditions: Experimental conditions; only the ``"polarization"`` entry is read.

    Returns:
        The polarization wrapped in ``\\textbf{...}``.  'lc' and 'rc' are upper-cased, and
        's-p' / 'p-s' are prefixed with a "Linear Dichroism" label.  Any other value is
        rendered as-is instead of raising ``KeyError``.
    """
    pol = conditions["polarization"]
    if pol in {"lc", "rc"}:
        return f"\\textbf{{{pol.upper()}}}"

    # Optional math-mode symbol per polarization.  All entries are currently empty, so the
    # plain bold label is what gets rendered.
    symbol_pol: dict[str, str] = {
        "s": "",
        "p": "",
        "s-p": "",
        "p-s": "",
    }
    prefix = "\\textbf{Linear Dichroism, }" if pol in {"s-p", "p-s"} else ""
    # .get: a polarization without a table entry simply has no symbol.
    symbol = symbol_pol.get(pol, "")
    label = f"\\textbf{{{pol}}}"
    if symbol:
        return f"{prefix}${symbol}$/{label}"
    return prefix + label
def _render_photon(c: dict[str, float]) -> str:
    """Format the ``"hv"`` entry of ``c`` as a bold LaTeX label followed by the unit eV."""
    photon_energy = c["hv"]
    return f"\\textbf{{{photon_energy!s} eV}}"
def annotate_cuts(
    ax: Axes,
    data: XrTypes,
    plotted_axes: NDArray[np.object_],
    *,
    include_text_labels: bool = False,
    **kwargs: tuple[float, ...] | list[float] | NDArray[np.float64],
) -> None:
    """Draw the location of one or more cuts onto an existing plot.

    Each cut is drawn as a red dashed line.  With ``include_text_labels`` a red text label
    is added next to the point of the line where the second plotted axis is smallest.

    Example:
        >>> annotate_cuts(ax, conv, ['kz', 'ky'], hv=80)  # doctest: +SKIP

    Args:
        ax: The axes to plot onto.
        data: The original data; it is forward-converted to momentum coordinates first.
        plotted_axes: The two dimension names which were plotted.
        include_text_labels: Whether to add a text label to each cut.
        kwargs: Maps a dimension name to the coordinate value(s) at which the cut is taken.
    """
    # NOTE:
    # Importing here (not at module level) avoids an import-time circular dependency:
    # plotting -> analysis -> xarray_extensions -> plotting
    from arpes.analysis.forward_conversion import (  # noqa: PLC0415
        convert_coordinates_to_kspace_forward,
    )

    kspace_coords = convert_coordinates_to_kspace_forward(data)
    assert isinstance(kspace_coords, xr.Dataset)
    assert len(plotted_axes) == TWO_DIMENSION

    for dim, cut_values in kwargs.items():
        nearest = kspace_coords.sel({dim: cut_values}, method="nearest")
        for cut_coords in nearest.G.iter_coords(dim):
            cut = nearest.sel(cut_coords, method="nearest")
            curve = [cut[axis_name].values for axis_name in plotted_axes]
            ax.plot(*curve, color="red", ls="--", linewidth=1, dashes=(5, 5))
            if not include_text_labels:
                continue
            # Anchor the label at the point with the smallest value along the second axis.
            lowest = np.argmin(curve[1])
            ax.text(
                curve[0][lowest] + 0.05,
                curve[1][lowest],
                f"{name_for_dim(dim)} = {cut_coords[dim].item()} {unit_for_dim(dim)}",
                color="red",
                size="medium",
            )
def annotate_point(
    ax: Axes | Axes3D,
    location: Sequence[float],
    delta: tuple[float, float] | tuple[float, float, float] = (-0.05, 0.05),
    **kwargs: Unpack[MPLTextParam],
) -> None:
    """Annotate a point or high symmetry location on a plot.

    A circular marker is drawn at ``location`` and a text label is placed at
    ``location + delta``.

    Args:
        ax: The axes to draw onto: 2D axes for a two-component ``delta``, otherwise 3D axes.
        location: Coordinates of the point, with as many components as ``delta``.
        delta: Offset of the text relative to the point.
        kwargs: Passed to ``ax.text``.  ``label`` selects the symbol to draw and is one of
            'G', 'X', 'Y', 'K', 'M'; any other value, or omitting it, draws an empty label.
            ``color`` defaults to 'red' and is also used for the marker.
    """
    symmetry_labels = {
        "G": "$\\Gamma$",
        "X": r"\textbf{X}",
        "Y": r"\textbf{Y}",
        "K": r"\textbf{K}",
        "M": r"\textbf{M}",
    }
    # Pop with a default so that ``label`` is always bound, even when the caller gave none.
    label = symmetry_labels.get(kwargs.pop("label", ""), "")

    assert isinstance(delta, tuple)
    kwargs.setdefault("color", "red")

    if len(delta) == TWO_DIMENSION:
        assert isinstance(ax, Axes)
        dx, dy = delta
        pos_x, pos_y = tuple(location)
        ax.plot(
            [pos_x],
            [pos_y],
            "o",
            c=kwargs["color"],
        )
        ax.text(
            pos_x + dx,
            pos_y + dy,
            s=label,
            **kwargs,
        )
    else:
        assert isinstance(ax, Axes3D)
        dx, dy, dz = delta
        pos_x, pos_y, pos_z = tuple(location)
        ax.plot(
            [pos_x],
            [pos_y],
            [pos_z],
            "o",
            c=kwargs["color"],
        )
        ax.text(
            pos_x + dx,
            pos_y + dy,
            pos_z + dz,
            label,
            **kwargs,
        )