Source code for arpes.xarray_extensions.accessor.spectroscopy

from __future__ import annotations  # noqa: D100

from logging import DEBUG, INFO
from typing import (
    TYPE_CHECKING,
    Self,
    Unpack,
)

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from arpes.constants import TWO_DIMENSION
from arpes.correction import coords
from arpes.debug import setup_logger
from arpes.plotting.dispersion import (
    fancy_dispersion,
    hv_reference_scan,
    labeled_fermi_surface,
    reference_scan_fermi_surface,
    scan_var_reference_plot,
)
from arpes.plotting.fermi_edge import fermi_edge_reference
from arpes.plotting.spatial import reference_scan_spatial
from arpes.plotting.ui import ProfileApp
from arpes.utilities import selections
from arpes.xarray_extensions._helper.spectroscopy import mean_other_impl, sum_other_impl

from .base import ARPESAccessorBase, ARPESDataArrayAccessorBase

if TYPE_CHECKING:
    from collections.abc import Hashable, Sequence
    from pathlib import Path

    from _typeshed import Incomplete
    from matplotlib.axes import Axes
    from matplotlib.figure import Figure
    from numpy.typing import NDArray
    from panel.layout import Panel

    from arpes._typing.attrs_property import CoordsOffset
    from arpes._typing.base import ReduceMethod
    from arpes._typing.plotting import (
        HvRefScanParam,
        LabeledFermiSurfaceParam,
        MPLPlotKwargs,
        PColorMeshKwargs,
        ProfileViewParam,
    )

# Selectable log levels for this module, most verbose first.
LOGLEVELS = DEBUG, INFO
# Active level: the last entry (INFO). Use LOGLEVELS[0] to get DEBUG output.
LOGLEVEL = LOGLEVELS[-1]
logger = setup_logger(__name__, LOGLEVEL)


@xr.register_dataarray_accessor("S")
class ARPESDataArrayAccessor(ARPESDataArrayAccessorBase):
    """Spectrum related accessor for `xr.DataArray`."""

    def __init__(self, xarray_obj: xr.DataArray) -> None:
        """Initialize the accessor.

        Args:
            xarray_obj (xr.DataArray): The DataArray this accessor is attached to.
        """
        self._obj: xr.DataArray = xarray_obj
        assert isinstance(self._obj, xr.DataArray)

    def corrected_coords(
        self,
        correction_types: CoordsOffset | Sequence[CoordsOffset],
    ) -> xr.DataArray:
        """Apply the specified coordinate corrections to the DataArray.

        Delegates to `coords.corrected_coords`.

        Args:
            correction_types (CoordsOffset | Sequence[CoordsOffset]): The types of
                corrections to apply.

        Returns:
            xr.DataArray: The corrected DataArray.
        """
        return coords.corrected_coords(self._obj, correction_types)

    def correct_coords(
        self,
        correction_types: CoordsOffset | Sequence[CoordsOffset],
    ) -> None:
        """Correct the coordinates of the DataArray in place.

        The corrected array is computed with `coords.corrected_coords`; its attrs
        and coords are then copied onto this DataArray.

        Args:
            correction_types (CoordsOffset | Sequence[CoordsOffset]): The types of
                corrections to apply.
        """
        array = coords.corrected_coords(self._obj, correction_types)
        self._obj.attrs = array.attrs
        self._obj.coords.update(array.coords)

    def sum_other(
        self,
        dim_or_dims: list[str],
        *,
        keep_attrs: bool = False,
    ) -> xr.DataArray:
        """See :meth:`ARPESDatasetAccessor.sum_other`."""
        return sum_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def mean_other(
        self,
        dim_or_dims: list[str] | str,
        *,
        keep_attrs: bool = False,
    ) -> xr.DataArray:
        """See :meth:`ARPESDatasetAccessor.mean_other`."""
        return mean_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def fat_sel(
        self,
        widths: dict[Hashable, float] | None = None,
        method: ReduceMethod = "mean",
        **kwargs: float,
    ) -> xr.DataArray:
        """See :meth:`ARPESDatasetAccessor.fat_sel`."""
        return selections.fat_sel(data=self._obj, widths=widths, method=method, **kwargs)

    # --- Methods about plotting
    # --- TODO : [RA] Consider refactoring/removing

    @staticmethod
    def _resolve_out_path(
        out: str | Path | bool | None,
        pattern: str,
        stem: str,
    ) -> str | Path | None:
        """Translate a boolean ``out`` flag into the value handed to plotting helpers.

        ``True`` becomes ``pattern.format(stem)``. ``False`` becomes ``""``, which is
        what the other helpers in this class use to mean "do not save". Any other
        value (``None``, a str, or a Path) is returned unchanged.

        Args:
            out: The user-supplied output target or flag.
            pattern: A format string with one ``{}`` placeholder for the file stem.
            stem: The stem inserted into ``pattern`` when ``out`` is ``True``.

        Returns:
            The output target to forward to the plotting function.
        """
        if isinstance(out, bool):
            return pattern.format(stem) if out else ""
        return out

    def plot(
        self: Self,
        *args: Incomplete,
        **kwargs: Incomplete,
    ) -> None:
        """Delegate to `xr.DataArray.plot`, rasterizing 2D plots by default.

        For 2D data, ``rasterized=True`` is set unless the caller supplies it.
        Plotting runs with ``text.usetex`` disabled.

        Args:
            args: Passed to xr.DataArray.plot.
            kwargs: Passed to xr.DataArray.plot (e.g. ``rasterized``).
        """
        if len(self._obj.dims) == TWO_DIMENSION:
            kwargs.setdefault("rasterized", True)
        with plt.rc_context(rc={"text.usetex": False}):
            self._obj.plot(*args, **kwargs)

    def show(self, **kwargs: Unpack[ProfileViewParam]) -> Panel:
        """Show a holoviews based plot via `ProfileApp`."""
        return ProfileApp(self._obj, **kwargs).panel()

    def fs_plot(
        self: Self,
        pattern: str = "{}.png",
        **kwargs: Unpack[LabeledFermiSurfaceParam],
    ) -> Path | tuple[Figure | None, Axes]:
        """Provide a reference plot of the approximate Fermi surface.

        Args:
            pattern: Output filename pattern used when ``out=True`` is passed.
            kwargs: Passed to `labeled_fermi_surface`. ``out=True`` is replaced by
                ``pattern.format(f"{label}_fs")`` and ``out=False`` by ``""``.

        Returns:
            Whatever `labeled_fermi_surface` returns.
        """
        assert isinstance(self._obj, xr.DataArray)
        kwargs["out"] = self._resolve_out_path(kwargs.get("out"), pattern, f"{self.label}_fs")
        return labeled_fermi_surface(self._obj, **kwargs)

    def fermi_edge_reference_plot(
        self: Self,
        pattern: str = "{}.png",
        out: str | Path | bool = "",
        **kwargs: Unpack[MPLPlotKwargs],
    ) -> Path | Axes:
        """Provide a reference plot for a Fermi edge reference.

        Calls `fermi_edge_reference` and forwards any additional keyword arguments
        for plot customization.

        Args:
            pattern (str): Pattern for the output file name. Its ``{}`` placeholder
                is replaced by ``f"{label}_fermi_edge_reference"`` when ``out`` is
                ``True``. Default is "{}.png".
            out (str | Path | bool): Where to save the figure. ``True`` derives the
                filename from ``pattern``; ``False`` or ``""`` means do not save.
            kwargs: Additional arguments passed to `fermi_edge_reference`.

        Returns:
            Path | Axes: Whatever `fermi_edge_reference` returns. Per the original
            documentation, this is the saved path when saving, otherwise the Axes.
        """
        assert isinstance(self._obj, xr.DataArray)
        out = self._resolve_out_path(out, pattern, f"{self.label}_fermi_edge_reference")
        return fermi_edge_reference(self._obj, out=out, **kwargs)

    def _referenced_scans_for_spatial_plot(
        self: Self,
        *,
        use_id: bool = True,
        pattern: str = "{}.png",
        out: str | Path | bool = "",
    ) -> Path | tuple[Figure, NDArray[np.object_]]:
        """Generate a spatial plot of referenced scans.

        Args:
            use_id (bool): If `True`, label with the "id" attribute of the data;
                otherwise use `self.label`. Default is `True`.
            pattern (str): Output file name pattern; ``{}`` is replaced by
                ``f"{label}_reference_scan_fs"`` when ``out`` is ``True``.
            out (str | Path | bool): ``True`` derives the filename from ``pattern``;
                ``False`` or ``""`` means do not save.

        Returns:
            Whatever `reference_scan_spatial` returns: per the original documentation,
            the saved path when saving, otherwise the Figure and an array of axes.
        """
        label = self._obj.attrs["id"] if use_id else self.label
        out = self._resolve_out_path(out, pattern, f"{label}_reference_scan_fs")
        return reference_scan_spatial(self._obj, out=out)

    def _referenced_scans_for_map_plot(
        self: Self,
        pattern: str = "{}.png",
        *,
        use_id: bool = True,
        **kwargs: Unpack[LabeledFermiSurfaceParam],
    ) -> Path | Axes:
        """Plot a Fermi-surface style reference for a map scan.

        A boolean ``out`` in ``kwargs`` is resolved: ``True`` uses ``pattern``,
        ``False`` becomes ``""``.
        """
        label = self._obj.attrs["id"] if use_id else self.label
        kwargs["out"] = self._resolve_out_path(
            kwargs.get("out"),
            pattern,
            f"{label}_reference_scan_fs",
        )
        return reference_scan_fermi_surface(self._obj, **kwargs)

    def _referenced_scans_for_hv_map_plot(
        self: Self,
        pattern: str = "{}.png",
        *,
        use_id: bool = True,
        **kwargs: Unpack[HvRefScanParam],
    ) -> Path | Axes:
        """Plot a reference for a photon-energy (hv) map scan.

        A boolean ``out`` in ``kwargs`` is resolved: ``True`` uses ``pattern``,
        ``False`` becomes ``""``. Previously the output name was unconditionally
        overwritten with ``f"{label}_hv_reference_scan.png"``, which ignored both
        ``pattern`` and the caller's ``out``.
        """
        label = self._obj.attrs["id"] if use_id else self.label
        kwargs["out"] = self._resolve_out_path(
            kwargs.get("out"),
            pattern,
            f"{label}_hv_reference_scan",
        )
        return hv_reference_scan(self._obj, **kwargs)

    def _simple_spectrum_reference_plot(
        self: Self,
        *,
        use_id: bool = True,
        pattern: str = "{}.png",
        out: str | Path | bool = "",
        **kwargs: Unpack[PColorMeshKwargs],
    ) -> Axes | Path:
        """Plot a simple spectrum reference with `fancy_dispersion`.

        ``out=True`` derives the filename from ``pattern``; ``out=False`` becomes ``""``.
        """
        assert isinstance(self._obj, xr.DataArray)
        label = self._obj.attrs["id"] if use_id else self.label
        out = self._resolve_out_path(out, pattern, f"{label}_spectrum_reference")
        return fancy_dispersion(self._obj, out=out, **kwargs)

    def reference_plot(
        self,
        **kwargs: Incomplete,
    ) -> Axes | Path | tuple[Figure, NDArray[np.object_]]:
        """Generate a reference plot for this data according to its spectrum type.

        Dispatches to ``self.spectrum_type.reference_plot``, passing this accessor.

        Args:
            kwargs: Passed through to the dispatch target (ultimately the
                ``_referenced_scans_for_*`` / plotting helpers).

        Raises:
            NotImplementedError: Per the original documentation, raised by the
                dispatch target when there is no standard plot for this data.

        Returns:
            Whatever the dispatch target returns (axes, a path, or figure/axes).
        """
        return self.spectrum_type.reference_plot(self, **kwargs)
@xr.register_dataset_accessor("S")
class ARPESDatasetAccessor(ARPESAccessorBase[xr.Dataset]):
    """Spectrum related accessor for `xr.Dataset`."""

    def __getattr__(self, item: str) -> object:
        """Forward attribute access to the spectrum's accessor, if necessary.

        Python only calls this when normal attribute lookup fails. The lookup is
        then retried on ``self._obj.S.spectrum.S``.

        Args:
            item: Attribute name.

        Raises:
            AttributeError: If ``_obj`` itself is missing (e.g. the instance was
                created without ``__init__``, as during copying/unpickling). Without
                this guard, ``self._obj`` below would re-enter this method forever.

        Returns:
            The attribute after lookup on the default spectrum's accessor.
        """
        if item == "_obj":
            raise AttributeError(item)
        return getattr(self._obj.S.spectrum.S, item)

    @property
    def spectrum(self) -> xr.DataArray:
        """Isolate a single spectrum from the dataset.

        A convenience for tools and analysis routines that operate on a single
        DataArray. For backward compatibility `load_data` still returns a Dataset,
        and this property recovers the DataArray.

        Lookup order:

        #. a data variable named ``"spectrum"``;
        #. a data variable named ``"raw"``;
        #. a data variable named ``"__xarray_dataarray_variable__"``;
        #. otherwise, among :attr:`spectra` (data variables with an ``"eV"``
           dimension), the one with the most elements. Ties go to the first.

        Returns:
            The selected spectrum.

        Raises:
            RuntimeError: If no candidate spectrum is found.
        """
        for name in ("spectrum", "raw", "__xarray_dataarray_variable__"):
            if name in self._obj.data_vars:
                return self._obj[name]
        candidates = self.spectra
        if not candidates:
            msg = "No spectrum found"
            raise RuntimeError(msg)
        # max() returns the first maximal element, matching "first largest wins".
        return max(candidates, key=lambda c: np.prod(c.shape))

    @property
    def spectra(self) -> list[xr.DataArray]:
        """Collect the data variables which are likely spectra.

        Returns:
            The data variables that have an ``"eV"`` dimension.
        """
        return [dv for dv in self._obj.data_vars.values() if "eV" in dv.dims]

    def reference_plot(self: Self, **kwargs: Incomplete) -> None:
        """Create reference plots for a dataset.

        A bit of a misnomer, because this makes several plots. Currently it produces:

        #. Scan-variable reference plots for whichever of ``T``, ``IG_nA``,
           ``current`` and ``photocurrent`` exist as data variables.
        #. Photocurrent-normalized spectrum reference plots (prefix ``norm_PC_``),
           but only when ``IG_nA`` is present.
        #. Unnormalized spectrum reference plots.

        TODO: temperature/photocurrent vs scan DoF for integrated data, and
        spatial-scan maps annotated with subsequent measurements, are not
        implemented yet.

        Args:
            kwargs: Currently unused; accepted for interface compatibility.
        """
        # Figures for temperature, photocurrent, delay.
        make_figures_for = ["T", "IG_nA", "current", "photocurrent"]
        name_normalization = {
            "T": "T",
            "IG_nA": "photocurrent",
            "current": "photocurrent",
        }
        for figure_item in make_figures_for:
            if figure_item not in self._obj.data_vars:
                continue
            name = name_normalization.get(figure_item, figure_item)
            data_var: xr.DataArray = self._obj[figure_item]
            out = f"{self.label}_{name}_spec_integrated_reference.png"
            scan_var_reference_plot(data_var, title=f"Reference {name}", out=out)

        # Photocurrent-normalized figures are only possible when the photocurrent
        # was recorded; previously a missing IG_nA raised AttributeError here.
        if "IG_nA" in self._obj:
            normalized = self._obj / self._obj.IG_nA
            normalized.S.make_spectrum_reference_plots(prefix="norm_PC_", out=True)
        self.make_spectrum_reference_plots(out=True)

    def make_spectrum_reference_plots(
        self,
        prefix: str = "",
        **kwargs: Incomplete,
    ) -> None:
        """Create spectrum reference plots (used for normalized and unnormalized data).

        Creates:

        #. The reference plot of :attr:`spectrum`.
        #. If a ``cycle`` coordinate exists: the reference plot of the data summed
           over the spectrum degrees of freedom (eV/phi/pixel/kx/kp/ky).
        #. If a ``delay`` coordinate exists: subtraction and Fermi-edge reference
           plots, the latter on the data summed over the non-eV spectrum DoFs.

        Args:
            prefix: A prefix inserted into filenames to make them unique.
            kwargs: Passed to the plotting routines.
        """
        self.spectrum.S.reference_plot(pattern=prefix + "{}.png", **kwargs)

        spectrum_degrees_of_freedom = set(self.spectrum.dims).intersection(
            {"eV", "phi", "pixel", "kx", "kp", "ky"},
        )

        # TODO: spatial scans (the original code referenced `self.referenced_scans`).

        if "cycle" in self._obj.coords:
            integrated_over_scan = self._obj.sum(spectrum_degrees_of_freedom)
            integrated_over_scan.S.spectrum.S.reference_plot(
                pattern=prefix + "sum_spec_DoF_{}.png",
                **kwargs,
            )

        if "delay" in self._obj.coords:
            # Build a new set, so the set above is not mutated and a missing
            # eV axis does not raise KeyError.
            angle_dims = spectrum_degrees_of_freedom - {"eV"}
            angle_integrated = self._obj.sum(angle_dims)

            # NOTE(review): `subtraction_reference_plots` and
            # `fermi_edge_reference_plots` are not defined in this module (the
            # DataArray accessor defines `fermi_edge_reference_plot`, singular).
            # Confirm they exist on a base accessor.
            self.spectrum.S.subtraction_reference_plots(pattern=prefix + "{}.png", **kwargs)
            angle_integrated.S.fermi_edge_reference_plots(pattern=prefix + "{}.png", **kwargs)

    def __init__(self, xarray_obj: xr.Dataset) -> None:
        """Initialization hook for xarray.

        Args:
            xarray_obj: The parent object which this is an accessor for.

        Note:
            This class should not be instantiated directly.
        """
        self._obj: xr.Dataset
        super().__init__(xarray_obj)

    def sum_other(
        self,
        dim_or_dims: list[str],
        *,
        keep_attrs: bool = False,
    ) -> xr.Dataset:
        """Sum over all dimensions *except* those specified.

        Inverts the usual selection of `xarray.Dataset.sum()`: instead of the
        dimensions to sum *along*, you list the dimensions to *keep*. Delegates to
        `sum_other_impl`.

        Args:
            dim_or_dims (list[str]): Dimension names to keep. The sum runs over all
                other dimensions.
            keep_attrs (bool, optional): Forwarded to `sum_other_impl`. Presumably
                preserves `.attrs` on the result. Defaults to False.

        Returns:
            xr.Dataset: The dataset summed over the "other" dimensions.

        Raises:
            AssertionError: Per the original documentation, raised by
                `sum_other_impl` when `dim_or_dims` is not a list. Confirm.

        Examples:
            >>> ds = xr.Dataset({"v": (("x", "y", "z"), np.ones((2, 3, 4)))})
            >>> ds.S.sum_other(["x"])["v"].values  # sums over "y" and "z"
            array([12., 12.])
            >>> ds.S.sum_other(["y", "z"])["v"].shape  # sums over "x"
            (3, 4)
        """
        return sum_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def mean_other(
        self,
        dim_or_dims: list[str] | str,
        *,
        keep_attrs: bool = False,
    ) -> xr.Dataset:
        """Average over all dimensions *except* those specified.

        Inverts the usual selection of `xarray.Dataset.mean()`: instead of the
        dimensions to average *along*, you list the dimensions to *keep*.
        Delegates to `mean_other_impl`.

        Args:
            dim_or_dims (list[str] | str): Dimension name(s) to keep. The mean runs
                over all other dimensions.
            keep_attrs (bool, optional): Forwarded to `mean_other_impl`. Presumably
                preserves `.attrs` on the result. Defaults to False.

        Returns:
            xr.Dataset: The dataset averaged over the "other" dimensions.

        Raises:
            AssertionError: Per the original documentation, `mean_other_impl` may
                assert that `dim_or_dims` is a list, despite the `str` in the type
                hint. Confirm.

        Examples:
            >>> ds = xr.Dataset({"v": (("x", "y", "z"), np.arange(12.0).reshape(2, 2, 3))})
            >>> ds.S.mean_other(["x"])["v"].values  # averages over "y" and "z"
            array([2.5, 8.5])
            >>> ds.S.mean_other(["y", "z"])["v"].values  # averages over "x"
            array([[3., 4., 5.],
                   [6., 7., 8.]])
        """
        return mean_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def fat_sel(
        self,
        widths: dict[Hashable, float] | None = None,
        method: ReduceMethod = "mean",
        **kwargs: float,
    ) -> xr.Dataset:
        """Perform a 'fat' selection, integrating over small windows around the selection.

        Delegates to `selections.fat_sel`. Reducing over a small coordinate window
        rather than a single slice lowers noise.

        Args:
            widths (dict[Hashable, float] | None, optional): Width of the
                integration window per dimension. Defaults to None, in which case
                `selections.fat_sel` decides the widths.
            method (ReduceMethod, optional): Reduction applied inside the window,
                e.g. "mean" (default) or "sum".
            **kwargs (float): Selection centres (e.g. ``eV=1.5``). Per the original
                documentation, ``<dim>_width`` keys are also accepted but deprecated
                in favour of ``widths``.

        Returns:
            xr.Dataset: The reduced dataset returned by `selections.fat_sel`.
        """
        return selections.fat_sel(data=self._obj, widths=widths, method=method, **kwargs)