Source code for arpes.plotting.spatial

"""Some common spatial plotting routines. Useful for contextualizing nanoARPES data."""

from __future__ import annotations

import contextlib
import itertools
from typing import TYPE_CHECKING, Any

import matplotlib as mpl
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from adjustText import adjust_text
from matplotlib import gridspec, patches

from arpes.constants import TWO_DIMENSION
from arpes.io import load_data
from arpes.provenance import save_plot_provenance
from arpes.utilities import normalize_to_spectrum
from arpes.utilities.xarray import unwrap_xarray_item

from .annotations import annotate_point
from .utils import (
    ddata_daxis_units,
    fancy_labels,
    frame_with,
    path_for_plot,
    remove_colorbars,
)

if TYPE_CHECKING:
    from collections.abc import Hashable, Sequence
    from pathlib import Path

    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from numpy.typing import NDArray

    from arpes._typing.base import DataType

# Public API of this module: the two plotting entry points defined below.
__all__ = ("plot_spatial_reference", "reference_scan_spatial")


@save_plot_provenance
def plot_spatial_reference(  # noqa: PLR0913, C901, PLR0912, PLR0915  # Might be removed in the future.
    reference_map: xr.DataArray,
    data_list: list[DataType],
    offset_list: Sequence[dict[str, Any] | None] | None = None,
    annotation_list: list[str] | None = None,
    out: str | Path = "",
    *,
    plot_refs: bool = True,
) -> Path | tuple[Figure, list[Axes]]:
    """Helpfully plots data against a reference scanning dataset.

    This is essential to understand where data was taken and can be used early in the analysis
    phase in order to highlight the location of your datasets against core levels, etc.

    Each dataset is drawn on top of the reference map as a point, a line or a rectangle,
    depending on how many of the reference map's two dimensions are array-valued coordinates
    of that dataset (zero, one or two respectively).

    Args:
        reference_map: A scanning photoemission like dataset
        data_list: A list of datasets you want to plot the relative locations of
        offset_list: Optionally, offsets given as coordinate dicts. A ``None`` entry means
            the offset is derived from the logical offsets of the data and the reference map.
        annotation_list: Optionally, text annotations for the data. Defaults to "1", "2", ...
        out: Where to save the figure if we are outputting to disk
        plot_refs: Whether to plot reference figures for each of the pieces of data in
            `data_list`

    Returns:
        The path of the saved figure if ``out`` is given, otherwise the figure together with
        the list of axes (the map axes first, followed by the per-dataset reference axes).
    """
    if offset_list is None:
        offset_list = [None] * len(data_list)

    if annotation_list is None:
        annotation_list = [str(i + 1) for i in range(len(data_list))]

    if not isinstance(reference_map, xr.DataArray):
        reference_map = normalize_to_spectrum(reference_map)

    n_references = len(data_list)
    ax_refs: list[Axes]
    if n_references == 1 and plot_refs:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        ax = axes[0]
        ax_refs = [axes[1]]
    elif plot_refs:
        # The map occupies a 2x2 cell on the left; the reference panels fill a two-row grid
        # to its right that grows by two columns for every four datasets.
        n_extra_axes = 1 + (n_references // 4)
        refs_per_row = 2 * n_extra_axes
        fig = plt.figure(figsize=(6 * (1 + n_extra_axes), 5))
        spec = gridspec.GridSpec(ncols=2 * (1 + n_extra_axes), nrows=2, figure=fig)
        ax = fig.add_subplot(spec[:2, :2])
        ax_refs = [
            fig.add_subplot(spec[i // refs_per_row, 2 + i % refs_per_row])
            for i in range(n_references)
        ]
    else:
        ax_refs = []
        fig, ax = plt.subplots(1, 1, figsize=(6, 5))

    # Best effort: use the first spectrum if the accessor can provide one, otherwise keep
    # the reference map as it was passed in.
    with contextlib.suppress(Exception):
        reference_map = reference_map.S.spectra[0]

    reference_map = reference_map.S.mean_other(["x", "y", "z"])

    ref_dims: tuple[Hashable, ...] = reference_map.dims[::-1]

    assert len(reference_map.dims) == TWO_DIMENSION
    reference_map.S.plot(ax=ax, cmap="Blues")

    cmap = mpl.colormaps.get_cmap("Reds")
    rendered_annotations = []

    # Dimensions preferred (in this order) for the per-dataset reference panels.
    keep_preference = [
        *list(ref_dims),
        "eV",
        "temperature",
        "kz",
        "hv",
        "kp",
        "kx",
        "ky",
        "phi",
        "theta",
        "beta",
        "pixel",
    ]
    scale = 0.03

    for i, (data, offset, annotation) in enumerate(
        zip(data_list, offset_list, annotation_list, strict=True),
    ):
        if offset is None:
            try:
                data_offsets = data.S.logical_offsets
                map_offsets = reference_map.S.logical_offsets
                # Each axis is compared against the same axis of the reference map.
                logical_offset = {
                    axis: data_offsets[axis] - map_offsets[axis] for axis in ("x", "y", "z")
                }
            except ValueError:
                logical_offset = {}
        else:
            logical_offset = offset

        coords = {c: unwrap_xarray_item(data.coords[c]) for c in ref_dims}
        n_array_coords = sum(
            isinstance(cv, np.ndarray | xr.DataArray) for cv in coords.values()
        )
        color_fraction = 0.4 + (0.5 * i / len(data_list))
        color = cmap(color_fraction)
        x = coords[ref_dims[0]] + logical_offset.get(str(ref_dims[0]), 0)
        y = coords[ref_dims[1]] + logical_offset.get(str(ref_dims[1]), 0)
        ref_x, ref_y = x, y
        off_x, off_y = 0, 0

        if n_array_coords == 0:
            # A single point on the map.
            off_y = 1
            ax.scatter([x], [y], s=60, color=color)

        if n_array_coords == 1:
            # A line cut: broadcast the scalar coordinate along the array-valued one.
            if isinstance(x, np.ndarray | xr.DataArray):
                y = [y] * len(x)
                ref_x = np.min(x)
                off_x = -1
            else:
                x = [x] * len(y)
                ref_y = np.max(y)
                off_y = 1

            ax.plot(x, y, color=color, linewidth=3)

        if n_array_coords == TWO_DIMENSION:
            # A map: draw its bounding rectangle, half transparent.
            off_y = 1
            min_x, max_x = np.min(x), np.max(x)
            min_y, max_y = np.min(y), np.max(y)
            ref_x, ref_y = min_x, max_y

            fill_color = cmap(color_fraction, alpha=0.5)
            rect = patches.Rectangle(
                (min_x, min_y),
                max_x - min_x,
                max_y - min_y,
                facecolor=fill_color,
            )
            ax.add_patch(rect)

        dp = ddata_daxis_units(ax)
        text_location = (
            np.asarray([ref_x, ref_y]) + dp * scale * np.asarray([off_x, off_y])
        ).tolist()
        text = ax.annotate(annotation, text_location, color="black", size=15)
        rendered_annotations.append(text)
        # White outline keeps the label readable on top of the map.
        text.set_path_effects(
            [path_effects.Stroke(linewidth=2, foreground="white"), path_effects.Normal()],
        )
        if plot_refs:
            ax_ref = ax_refs[i]
            keep = [d for d in keep_preference if d in data.dims][:2]
            data.S.mean_other(keep).S.plot(ax=ax_ref)
            ax_ref.set_title(annotation)
            fancy_labels(ax_ref)
            frame_with(ax_ref, color=color, linewidth=3)

    ax.set_title("")
    remove_colorbars()
    fancy_labels(ax)
    plt.tight_layout()

    with contextlib.suppress(NameError):
        adjust_text(
            rendered_annotations,
            ax=ax,
            avoid_points=False,
            avoid_objects=False,
            avoid_self=False,
            autoalign="xy",
        )

    if out:
        # Resolve the destination once so the saved file and the returned path agree.
        out_path = path_for_plot(out)
        plt.savefig(out_path, dpi=400)
        return out_path

    return fig, [ax, *ax_refs]
@save_plot_provenance
def reference_scan_spatial(
    data: xr.DataArray,
    out: str | Path = "",
) -> Path | tuple[Figure, NDArray[np.object_]]:
    """Plots the spatial content of a dataset, useful as a quick reference.

    Six panels are drawn: the data summed over the "cycle", "phi" and "eV" dimensions that
    are present, followed by five consecutive energy windows below ``min(0, max(eV))``.

    Warning:
        The annotations rely on ``S.referenced_scans``, which has been removed from the
        accessor. When it is not available a warning is emitted and the panels are drawn
        without annotations.

    Args:
        data: The dataset to plot. Anything that is not a DataArray is passed through
            ``normalize_to_spectrum`` first.
        out: Where to save the figure if we are outputting to disk

    Returns:
        The path of the saved figure if ``out`` is given, otherwise the figure and the
        array of axes created by ``plt.subplots``.
    """
    import warnings

    data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
    assert isinstance(data, xr.DataArray)

    dims = [d for d in data.dims if d in {"cycle", "phi", "eV"}]
    summed_data = data.sum(dims, keep_attrs=True)

    fig, ax = plt.subplots(3, 2, figsize=(15, 15))
    flat_axes = list(itertools.chain.from_iterable(ax))

    summed_data.S.plot(ax=flat_axes[0])
    flat_axes[0].set_title(r"Full \textbf{eV} range")

    dims_except_eV = [d for d in dims if d != "eV"]  # noqa: N806
    summed_data = data.sum(dims_except_eV, keep_attrs=True)

    energy_max = data.coords["eV"].max().item()
    rng = energy_max - data.coords["eV"].min().item()
    offset = min(0, energy_max)
    # Window width: a fifth of the energy range for wide scans, otherwise a fixed 0.2.
    mul = rng / 5.0 if rng > 3 else 0.2  # noqa: PLR2004

    for i in range(5):
        low_e, high_e = -mul * (i + 1) + offset, -mul * i + offset
        title = r"\textbf{eV}" + f": {low_e:.2g} to {high_e:.2g}"
        summed_data.sel(eV=slice(low_e, high_e)).sum("eV", keep_attrs=True).S.plot(
            ax=flat_axes[i + 1],
        )
        flat_axes[i + 1].set_title(title)

    y_range = flat_axes[0].get_ylim()
    x_range = flat_axes[0].get_xlim()
    delta_one_percent = ((x_range[1] - x_range[0]) / 100, (y_range[1] - y_range[0]) / 100)

    # The y component is scaled by the y range of the axes, not by the x range.
    # NOTE(review): assumes `annotate_point` applies `delta` per axis - confirm.
    smart_delta: tuple[float, float] | tuple[float, float, float] = (
        2 * delta_one_percent[0],
        -1.5 * delta_one_percent[1],
    )

    # idea here is to collect points by those that are close together, then
    # only plot one annotation
    condensed: list[tuple[float, float, list[int]]] = []
    cutoff = 3  # 3 percent

    referenced = getattr(data.S, "referenced_scans", None)
    if referenced is None:
        warnings.warn(
            "S.referenced_scans is not available; "
            "plotting without referenced-scan annotations.",
            stacklevel=2,
        )
    else:
        x_tolerance = cutoff * abs(delta_one_percent[0])
        y_tolerance = cutoff * abs(delta_one_percent[1])
        for index, _ in referenced.iterrows():
            ff = load_data(index)
            x, y, _ = ff.S.sample_pos

            # Attach this scan to the first existing cluster within tolerance on both axes.
            for cx, cy, cl in condensed:
                if abs(cx - x) < x_tolerance and abs(cy - y) < y_tolerance:
                    cl.append(index)
                    break
            else:
                condensed.append((x, y, [index]))

    for fax in flat_axes:
        for cx, cy, cl in condensed:
            annotate_point(
                fax,
                (cx, cy),
                ",".join(str(scan_id) for scan_id in cl),
                delta=smart_delta,
                fontsize="large",
            )

    plt.tight_layout()

    if out:
        # Resolve the destination once so the saved file and the returned path agree.
        out_path = path_for_plot(out)
        plt.savefig(out_path, dpi=400)
        return out_path

    return fig, ax