from __future__ import annotations # noqa: D100
from logging import DEBUG, INFO
from typing import (
TYPE_CHECKING,
Self,
Unpack,
)
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from arpes.constants import TWO_DIMENSION
from arpes.correction import coords
from arpes.debug import setup_logger
from arpes.plotting.dispersion import (
fancy_dispersion,
hv_reference_scan,
labeled_fermi_surface,
reference_scan_fermi_surface,
scan_var_reference_plot,
)
from arpes.plotting.fermi_edge import fermi_edge_reference
from arpes.plotting.spatial import reference_scan_spatial
from arpes.plotting.ui import ProfileApp
from arpes.utilities import selections
from arpes.xarray_extensions._helper.spectroscopy import mean_other_impl, sum_other_impl
from .base import ARPESAccessorBase, ARPESDataArrayAccessorBase
if TYPE_CHECKING:
from collections.abc import Hashable, Sequence
from pathlib import Path
from _typeshed import Incomplete
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from numpy.typing import NDArray
from panel.layout import Panel
from arpes._typing.attrs_property import CoordsOffset
from arpes._typing.base import ReduceMethod
from arpes._typing.plotting import (
HvRefScanParam,
LabeledFermiSurfaceParam,
MPLPlotKwargs,
PColorMeshKwargs,
ProfileViewParam,
)
# Logging levels this module can be switched between while debugging.
LOGLEVELS = (DEBUG, INFO)
# Default to INFO (the last entry); use ``LOGLEVELS[0]`` for DEBUG output.
LOGLEVEL = LOGLEVELS[-1]
logger = setup_logger(__name__, LOGLEVEL)
@xr.register_dataarray_accessor("S")
class ARPESDataArrayAccessor(ARPESDataArrayAccessorBase):
    """Spectrum related accessor for `xr.DataArray`."""

    def __init__(self, xarray_obj: xr.DataArray) -> None:
        """Initialize the accessor.

        Args:
            xarray_obj: The DataArray this accessor is attached to.
        """
        self._obj: xr.DataArray = xarray_obj
        assert isinstance(self._obj, xr.DataArray)

    def corrected_coords(
        self,
        correction_types: CoordsOffset | Sequence[CoordsOffset],
    ) -> xr.DataArray:
        """Return the DataArray with the specified coordinate corrections applied.

        Args:
            correction_types (CoordsOffset | Sequence[CoordsOffset]): The types of corrections to
                apply.

        Returns:
            xr.DataArray: The corrected DataArray.
        """
        return coords.corrected_coords(self._obj, correction_types)

    def correct_coords(
        self,
        correction_types: CoordsOffset | Sequence[CoordsOffset],
    ) -> None:
        """Correct the coordinates of the DataArray in place.

        The corrections are computed with ``coords.corrected_coords``. The resulting attrs and
        coordinates are then written back onto the wrapped DataArray.

        Args:
            correction_types (CoordsOffset | Sequence[CoordsOffset]): The types of corrections
                to apply.
        """
        array = coords.corrected_coords(self._obj, correction_types)
        self._obj.attrs = array.attrs
        self._obj.coords.update(array.coords)

    def sum_other(
        self,
        dim_or_dims: list[str],
        *,
        keep_attrs: bool = False,
    ) -> xr.DataArray:
        """See :meth:`ARPESDatasetAccessor.sum_other`."""
        return sum_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def mean_other(
        self,
        dim_or_dims: list[str] | str,
        *,
        keep_attrs: bool = False,
    ) -> xr.DataArray:
        """See :meth:`ARPESDatasetAccessor.mean_other`."""
        return mean_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def fat_sel(
        self,
        widths: dict[Hashable, float] | None = None,
        method: ReduceMethod = "mean",
        **kwargs: float,
    ) -> xr.DataArray:
        """See :meth:`ARPESDatasetAccessor.fat_sel`."""
        return selections.fat_sel(data=self._obj, widths=widths, method=method, **kwargs)

    # --- Methods about plotting
    # --- TODO : [RA] Consider refactoring/removing

    def plot(
        self: Self,
        *args: Incomplete,
        **kwargs: Incomplete,
    ) -> None:
        """Delegate to `xr.DataArray.plot`, rasterizing 2D data by default.

        LaTeX text rendering (``text.usetex``) is disabled for the duration of the call.

        Args:
            args: Passed to `xr.DataArray.plot`.
            kwargs: Passed to `xr.DataArray.plot`. For 2D data, ``rasterized`` (raster rather
                than vector drawing) defaults to ``True`` unless given explicitly.
        """
        if len(self._obj.dims) == TWO_DIMENSION:
            kwargs.setdefault("rasterized", True)
        with plt.rc_context(rc={"text.usetex": False}):
            self._obj.plot(*args, **kwargs)

    def show(self, **kwargs: Unpack[ProfileViewParam]) -> Panel:
        """Show a holoviews based plot built by :class:`ProfileApp`."""
        return ProfileApp(self._obj, **kwargs).panel()

    def fs_plot(
        self: Self,
        pattern: str = "{}.png",
        **kwargs: Unpack[LabeledFermiSurfaceParam],
    ) -> Path | tuple[Figure | None, Axes]:
        """Provides a reference plot of the approximate Fermi surface.

        Args:
            pattern (str): Format string used to build the output file name when ``out=True``.
                ``{}`` is replaced by ``"<label>_fs"``.
            kwargs: Passed to `labeled_fermi_surface`. A boolean ``out`` is resolved first:
                ``True`` becomes a name built from ``pattern``, and ``False`` becomes ``""``
                (no output file).

        Returns:
            Path | tuple[Figure | None, Axes]: The value returned by `labeled_fermi_surface`.
        """
        assert isinstance(self._obj, xr.DataArray)
        out = kwargs.get("out")
        if isinstance(out, bool):
            # Previously ``out=False`` also produced a file name, so a figure was saved anyway.
            out = pattern.format(f"{self.label}_fs") if out else ""
        kwargs["out"] = out
        return labeled_fermi_surface(self._obj, **kwargs)

    def fermi_edge_reference_plot(
        self: Self,
        pattern: str = "{}.png",
        out: str | Path | bool = "",
        **kwargs: Unpack[MPLPlotKwargs],
    ) -> Path | Axes:
        """Provides a reference plot for a Fermi edge reference.

        This function calls `fermi_edge_reference` and passes any additional keyword arguments
        to it for plotting customization. The output file name can be given explicitly through
        ``out`` or derived from ``pattern``.

        Args:
            pattern (str): Format string for the output file name. ``{}`` is replaced by
                ``"<label>_fermi_edge_reference"``. Default is "{}.png".
            out (str | Path | bool): The path for saving the output figure. If ``True``, the
                name is generated from ``pattern``. If ``False`` it is replaced by ``""``, and
                ``""``/``None`` are passed through, so no file name is supplied.
            kwargs: Additional arguments passed to `fermi_edge_reference` for customizing the
                plot.

        Returns:
            Path | Axes: The path to the saved figure (if an output path is given), or the Axes
            object of the plot.
        """
        assert isinstance(self._obj, xr.DataArray)
        if isinstance(out, bool):
            out = pattern.format(f"{self.label}_fermi_edge_reference") if out else ""
        return fermi_edge_reference(self._obj, out=out, **kwargs)

    def _referenced_scans_for_spatial_plot(
        self: Self,
        *,
        use_id: bool = True,
        pattern: str = "{}.png",
        out: str | Path | bool = "",
    ) -> Path | tuple[Figure, NDArray[np.object_]]:
        """Helper function for generating a spatial plot of referenced scans.

        The label is taken from the ``"id"`` attribute or from ``self.label``. The plot itself is
        produced by `reference_scan_spatial`.

        Args:
            use_id (bool): If ``True``, uses ``attrs["id"]`` as the label. If ``False``, uses
                ``self.label``. Default is ``True``.
            pattern (str): Format string for the output file name. ``{}`` is replaced by
                ``"<label>_reference_scan_fs"``. Default is ``"{}.png"``.
            out (str | Path | bool): The path to save the output figure. If ``True``, the file
                name is generated from ``pattern``. If ``False`` or ``""``, no output is saved.

        Returns:
            Path | tuple[Figure, NDArray[np.object_]]:
                - If ``out`` is provided, the path to the saved figure.
                - Otherwise, the Figure and an array of axes objects.
        """
        label = self._obj.attrs["id"] if use_id else self.label
        if isinstance(out, bool):
            out = pattern.format(f"{label}_reference_scan_fs") if out else ""
        return reference_scan_spatial(self._obj, out=out)

    def _referenced_scans_for_map_plot(
        self: Self,
        pattern: str = "{}.png",
        *,
        use_id: bool = True,
        **kwargs: Unpack[LabeledFermiSurfaceParam],
    ) -> Path | Axes:
        """Reference Fermi-surface plot via `reference_scan_fermi_surface`.

        Args:
            pattern (str): Format string for the output file name. ``{}`` is replaced by
                ``"<label>_reference_scan_fs"``.
            use_id (bool): If ``True``, label with ``attrs["id"]``; otherwise use ``self.label``.
            kwargs: Passed to `reference_scan_fermi_surface`. A boolean ``out`` is resolved
                first: ``True`` becomes a name built from ``pattern``, and ``False`` becomes
                ``""``.

        Returns:
            Path | Axes: The value returned by `reference_scan_fermi_surface`.
        """
        out = kwargs.get("out")
        label = self._obj.attrs["id"] if use_id else self.label
        if isinstance(out, bool):
            out = pattern.format(f"{label}_reference_scan_fs") if out else ""
        kwargs["out"] = out
        return reference_scan_fermi_surface(self._obj, **kwargs)

    def _referenced_scans_for_hv_map_plot(
        self: Self,
        pattern: str = "{}.png",
        *,
        use_id: bool = True,
        **kwargs: Unpack[HvRefScanParam],
    ) -> Path | Axes:
        """Photon-energy reference plot via `hv_reference_scan`.

        Args:
            pattern (str): Format string for the output file name. ``{}`` is replaced by
                ``"<label>_hv_reference_scan"``.
            use_id (bool): If ``True``, label with ``attrs["id"]``; otherwise use ``self.label``.
            kwargs: Passed to `hv_reference_scan`. A boolean ``out`` is resolved first:
                ``True`` becomes a name built from ``pattern``, and ``False`` becomes ``""``.

        Returns:
            Path | Axes: The value returned by `hv_reference_scan`.
        """
        out = kwargs.get("out")
        label = self._obj.attrs["id"] if use_id else self.label
        if isinstance(out, bool):
            # The name is built from ``pattern`` so that caller prefixes (e.g. "norm_PC_") are
            # preserved. A hard-coded override used to discard them.
            out = pattern.format(f"{label}_hv_reference_scan") if out else ""
        kwargs["out"] = out
        return hv_reference_scan(self._obj, **kwargs)

    def _simple_spectrum_reference_plot(
        self: Self,
        *,
        use_id: bool = True,
        pattern: str = "{}.png",
        out: str | Path | bool = "",
        **kwargs: Unpack[PColorMeshKwargs],
    ) -> Axes | Path:
        """Spectrum reference plot via `fancy_dispersion`.

        Args:
            use_id (bool): If ``True``, label with ``attrs["id"]``; otherwise use ``self.label``.
            pattern (str): Format string for the output file name. ``{}`` is replaced by
                ``"<label>_spectrum_reference"``.
            out (str | Path | bool): Output path. ``True`` builds a name from ``pattern`` and
                ``False`` becomes ``""``.
            kwargs: Passed to `fancy_dispersion`.

        Returns:
            Axes | Path: The value returned by `fancy_dispersion`.
        """
        assert isinstance(self._obj, xr.DataArray)
        label = self._obj.attrs["id"] if use_id else self.label
        if isinstance(out, bool):
            out = pattern.format(f"{label}_spectrum_reference") if out else ""
        return fancy_dispersion(self._obj, out=out, **kwargs)

    def reference_plot(
        self,
        **kwargs: Incomplete,
    ) -> Axes | Path | tuple[Figure, NDArray[np.object_]]:
        """Generates a reference plot for this piece of data according to its spectrum type.

        Args:
            kwargs: Passed to the spectrum-type specific plotting routine
                (the ``_referenced_scans_for_*`` helpers).

        Raises:
            NotImplementedError: If there is no standard approach for plotting this data.

        Returns:
            The result of the spectrum-type specific routine: the Axes used for plotting, the
            path of a saved figure, or a (Figure, axes array) tuple.
        """
        return self.spectrum_type.reference_plot(self, **kwargs)
@xr.register_dataset_accessor("S")
class ARPESDatasetAccessor(ARPESAccessorBase[xr.Dataset]):
    """Spectrum related accessor for `xr.Dataset`."""

    def __init__(self, xarray_obj: xr.Dataset) -> None:
        """Initialization hook for xarray.

        Args:
            xarray_obj: The parent object which this is an accessor for

        Note:
            This class should not be called directly.
        """
        self._obj: xr.Dataset
        super().__init__(xarray_obj)

    def __getattr__(self, item: str) -> Incomplete:
        """Forward attribute access to the spectrum, if necessary.

        Python only calls this for names that normal attribute lookup could not find. Such names
        are looked up on the ``S`` accessor of :attr:`spectrum`.

        Args:
            item: Attribute name

        Returns:
            The attribute after lookup on the default spectrum

        Raises:
            AttributeError: If ``_obj`` itself is requested but missing. This happens, for
                example, on an instance created by ``copy``/``pickle`` without ``__init__``.
                Forwarding in that case would recurse forever.
        """
        if item == "_obj":
            raise AttributeError(item)
        return getattr(self._obj.S.spectrum.S, item)

    @property
    def spectrum(self) -> xr.DataArray:
        """Isolates a single spectrum from a dataset.

        This is a convenience method which is typically used in startup for
        tools and analysis routines which need to operate on a single
        piece of data.

        Historically, the handling of Dataset and DataArray was a mess in previous pyarpes.
        Most of the current pyarpes methods/functions are sufficient to treat DataArray as the
        main object. (The few exceptions are S.modelfit, whose return value is a Dataset, which
        is reasonable.) For backward compatibility, the return of load_data is still a Dataset,
        so in many cases, using this property for a DataArray will provide a more robust
        analysing environment.

        Selection order:

        #. the data variable named ``"spectrum"``;
        #. the data variable named ``"raw"``;
        #. the data variable named ``"__xarray_dataarray_variable__"`` (the name xarray uses
           for an unnamed DataArray);
        #. otherwise, among the data variables with an ``"eV"`` dimension (see
           :attr:`spectra`), the one with the most elements. Ties go to the first candidate.

        Returns:
            A spectrum found in the dataset.

        Raises:
            RuntimeError: If no candidate spectrum exists.

        Todo: Need test
        """
        if "spectrum" in self._obj.data_vars:
            return self._obj.spectrum
        if "raw" in self._obj.data_vars:
            return self._obj.raw
        if "__xarray_dataarray_variable__" in self._obj.data_vars:
            return self._obj.__xarray_dataarray_variable__
        candidates = self.spectra
        if not candidates:
            msg = "No spectrum found"
            raise RuntimeError(msg)
        # ``max`` returns the first maximal element, matching the earlier strict ``>`` scan.
        return max(candidates, key=lambda candidate: np.prod(candidate.shape))

    @property
    def spectra(self) -> list[xr.DataArray]:
        """Collects the variables which are likely spectra.

        Returns:
            The data_vars that have an ``"eV"`` dimension.
        """
        return [dv for dv in self._obj.data_vars.values() if "eV" in dv.dims]

    def reference_plot(self: Self, **kwargs: Incomplete) -> None:
        """Creates reference plots for a dataset.

        A bit of a misnomer because this actually makes several plots:

        #. Scan-variable reference plots for each of the ``T``, ``IG_nA``, ``current`` and
           ``photocurrent`` data variables that is present.
        #. Photocurrent-normalized spectrum reference plots (file prefix ``norm_PC_``), made
           only when ``IG_nA`` is available.
        #. Unnormalized spectrum reference plots via :meth:`make_spectrum_reference_plots`.

        Args:
            kwargs: Currently unused; accepted for backward compatibility.
        """
        # Make figures for temperature and photocurrent.
        make_figures_for = ("T", "IG_nA", "current", "photocurrent")
        name_normalization = {
            "T": "T",
            "IG_nA": "photocurrent",
            "current": "photocurrent",
        }
        for figure_item in make_figures_for:
            if figure_item not in self._obj.data_vars:
                continue
            name = name_normalization.get(figure_item, figure_item)
            data_var: xr.DataArray = self._obj[figure_item]
            out = f"{self.label}_{name}_spec_integrated_reference.png"
            scan_var_reference_plot(data_var, title=f"Reference {name}", out=out)

        # Photocurrent-normalized figures. Without IG_nA this used to raise and prevented the
        # unnormalized plots below from being made.
        try:
            photocurrent = self._obj.IG_nA
        except AttributeError:
            logger.info("No IG_nA found; skipping photocurrent normalized reference plots.")
        else:
            normalized = self._obj / photocurrent
            normalized.S.make_spectrum_reference_plots(prefix="norm_PC_", out=True)

        self.make_spectrum_reference_plots(out=True)

    def make_spectrum_reference_plots(
        self,
        prefix: str = "",
        **kwargs: Incomplete,
    ) -> None:
        """Creates spectrum reference figures.

        Creates:
            #. The reference plot of :attr:`spectrum`.
            #. If a ``cycle`` coordinate exists, a reference plot of the data summed over the
               spectrum degrees of freedom (any of ``eV``, ``phi``, ``pixel``, ``kx``, ``kp``,
               ``ky`` present on the spectrum).
            #. If a ``delay`` coordinate exists, subtraction reference plots and a Fermi edge
               reference plot of the data summed over the non-energy spectrum dimensions.

        Args:
            prefix: A prefix inserted into filenames to make them unique.
            kwargs: Passed to plotting routines to provide user control over plotting
                behavior
        """
        spectrum = self.spectrum
        spectrum.S.reference_plot(pattern=prefix + "{}.png", **kwargs)
        spectrum_degrees_of_freedom = set(spectrum.dims).intersection(
            {"eV", "phi", "pixel", "kx", "kp", "ky"},
        )

        if "cycle" in self._obj.coords:
            integrated_over_scan = self._obj.sum(spectrum_degrees_of_freedom)
            integrated_over_scan.S.spectrum.S.reference_plot(
                pattern=prefix + "sum_spec_DoF_{}.png",
                **kwargs,
            )

        if "delay" in self._obj.coords:
            # A set difference avoids mutating the shared set and tolerates a missing "eV".
            angle_integrated = self._obj.sum(spectrum_degrees_of_freedom - {"eV"})
            # subtraction scan
            # NOTE(review): ``subtraction_reference_plots`` is not defined in this module;
            # confirm it exists on the DataArray accessor base class.
            spectrum.S.subtraction_reference_plots(pattern=prefix + "{}.png", **kwargs)
            angle_integrated.S.fermi_edge_reference_plot(pattern=prefix + "{}.png", **kwargs)

    def sum_other(
        self,
        dim_or_dims: list[str],
        *,
        keep_attrs: bool = False,
    ) -> xr.Dataset:
        """Calculates the sum over all dimensions *except* those specified.

        This is a convenience method for `xarray.Dataset.sum()` that inverts the selection of
        dimensions. Instead of specifying dimensions to sum *along*, you specify dimensions to
        *keep*.

        Args:
            dim_or_dims (list[str]): A list of dimension names to keep. The sum
                operation will be performed over all other dimensions not in this list.
            keep_attrs (bool, optional): If True, attributes (`.attrs`) will be
                preserved on the returned object. Defaults to False.

        Returns:
            xr.Dataset: The data summed along the "other" dimensions, so that only those
            listed in `dim_or_dims` remain.

        Note:
            Validation of `dim_or_dims` is delegated to ``sum_other_impl``.

        Examples:
            >>> ds = xr.Dataset({"spectrum": (("x", "y", "z"), np.ones((2, 3, 4)))})
            >>> ds.S.sum_other(["x"]).spectrum.values  # sums over 'y' and 'z'
            array([12., 12.])
            >>> ds.S.sum_other(["y", "z"]).spectrum.shape  # sums over 'x'
            (3, 4)
        """
        return sum_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def mean_other(
        self,
        dim_or_dims: list[str] | str,
        *,
        keep_attrs: bool = False,
    ) -> xr.Dataset:
        """Calculates the mean over all dimensions *except* those specified.

        This is a convenience method for `xarray.Dataset.mean()` that inverts the selection of
        dimensions. Instead of specifying dimensions to average *along*, you specify dimensions
        to *keep*.

        Args:
            dim_or_dims (list[str] | str): A list of dimension names to keep, or a single
                dimension name string. The mean operation will be performed over all other
                dimensions.
            keep_attrs (bool, optional): If True, attributes (`.attrs`) will be
                preserved on the returned object. Defaults to False.

        Returns:
            xr.Dataset: The data averaged along the "other" dimensions, so that only those
            listed in `dim_or_dims` remain.

        Note:
            Validation of `dim_or_dims` is delegated to ``mean_other_impl``.

        Examples:
            >>> ds = xr.Dataset({"spectrum": (("x", "y", "z"), np.arange(12.0).reshape(2, 2, 3))})
            >>> ds.S.mean_other(["x"]).spectrum.values  # averages over 'y' and 'z'
            array([2.5, 8.5])
            >>> ds.S.mean_other(["y", "z"]).spectrum.values  # averages over 'x'
            array([[3., 4., 5.],
                   [6., 7., 8.]])
        """
        return mean_other_impl(self._obj, dim_or_dims, keep_attrs=keep_attrs)

    def fat_sel(
        self,
        widths: dict[Hashable, float] | None = None,
        method: ReduceMethod = "mean",
        **kwargs: float,
    ) -> xr.Dataset:
        """Performs a 'fat' selection, integrating data over small regions specified by widths.

        This method integrates a selection over a small coordinate region (defined by `widths`
        or keyword arguments). This reduces noise by averaging or summing over nearby slices.

        Args:
            widths (dict[Hashable, float] | None, optional): A dictionary
                specifying the width of the integration window for each dimension.
                Keys are dimension names (Hashable), and values are float widths.
                Overrides any widths specified in `kwargs`. Defaults to None,
                in which case `selections.fat_sel` may fall back to default widths.
            method (ReduceMethod, optional): The reduction method to apply within
                the selection window, e.g. "mean" (default) or "sum".
            **kwargs (float): Keyword arguments that can define specific selection
                points (e.g., `eV=1.5`) or widths (e.g., `eV_width=0.1`).
                **Note**: Using `*_width` in kwargs for specifying widths is
                deprecated. Prefer the `widths` dictionary argument.

        Returns:
            xr.Dataset: The Dataset after the 'fat' selection and reduction have been applied.

        Note:
            The `widths` argument is the preferred way to specify integration
            widths. Using `*_width` through `kwargs` is deprecated and may
            be removed in future versions.
        """
        return selections.fat_sel(data=self._obj, widths=widths, method=method, **kwargs)