"""This module provides functions to manipulate coordinates in xarray DataArrays."""
from __future__ import annotations
import warnings
from collections import Counter
from logging import DEBUG, INFO
from typing import TYPE_CHECKING, LiteralString, get_args
import numpy as np
import xarray as xr
from arpes._typing.attrs_property import CoordsOffset
from arpes.debug import setup_logger
if TYPE_CHECKING:
from collections.abc import Hashable, Mapping, Sequence
from _typeshed import Incomplete
from numpy.typing import NDArray
from arpes.provenance import Provenance
# Log levels this module's logger may be configured with; the last entry (INFO) is active.
LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[-1]
logger = setup_logger(__name__, LOGLEVEL)

# Public API of this module.
__all__ = (
    "adjust_coords_to_limit",
    "corrected_coords",
    "extend_coords",
    "is_equally_spaced",
    "shift_by",
)
def _grid_step(coords: NDArray[np.floating], dim: Hashable) -> float:
    """Return the positive median spacing of ``coords``.

    Raises:
        ValueError: If the spacing cannot be inferred (fewer than two values, or a zero/NaN
            median spacing).
    """
    if coords.size < 2:
        msg = f"Cannot infer the spacing of coordinate {dim!r} from fewer than two values."
        raise ValueError(msg)
    # abs() lets descending coordinates be extended exactly like ascending ones.
    step = float(np.abs(np.median(np.diff(coords))))
    if not step > 0:  # also rejects NaN
        msg = f"Coordinate {dim!r} has no usable spacing (median step = {step})."
        raise ValueError(msg)
    return step


def _steps_to_cover(distance: float, step: float) -> int:
    """Return the smallest number of ``step`` increments whose total reaches ``distance``."""
    # The small tolerance keeps float noise (e.g. 0.30000000000000004 / 0.1) from adding a
    # spurious extra grid point beyond the requested limit.
    return int(np.ceil(distance / step - 1e-9))


def adjust_coords_to_limit(
    da: xr.DataArray,
    new_limits: Mapping[Hashable, float],
) -> dict[Hashable, NDArray[np.floating]]:
    """Compute the coordinate values needed to extend ``da`` up to the given limits.

    For each dimension the new values continue the existing grid (spacing = median of the
    coordinate differences) outward from the current maximum or minimum until ``new_limit``
    is covered. Only the newly added values are returned, in ascending order; they are
    intended to be passed on to :func:`extend_coords`.

    Args:
        da (xr.DataArray): The original DataArray with (approximately) equidistant coordinates.
        new_limits (Mapping[Hashable, float]): The value each coordinate should be extended to,
            e.g. ``{"x": 5, "y": -1}``.

    Returns:
        dict[Hashable, NDArray[np.floating]]: The newly added coordinate values for each
        dimension. An empty array is returned for a dimension whose limit already lies within
        its current coordinate range.

    Raises:
        ValueError: If an extension is requested for a coordinate whose spacing cannot be
            inferred (fewer than two values, or zero spacing).
    """
    new_coords_dict: dict[Hashable, NDArray[np.floating]] = {}
    for dim, new_limit in new_limits.items():
        coords = np.asarray(da.coords[dim].values)
        min_coord = np.min(coords)
        max_coord = np.max(coords)
        if new_limit > max_coord:
            step = _grid_step(coords, dim)
            n_steps = _steps_to_cover(new_limit - max_coord, step)
            # max + step, max + 2*step, ..., max + n_steps*step
            new_coords = max_coord + step * np.arange(1, n_steps + 1)
        elif new_limit < min_coord:
            step = _grid_step(coords, dim)
            n_steps = _steps_to_cover(min_coord - new_limit, step)
            # min - n_steps*step, ..., min - step: ascending and aligned with the existing grid
            new_coords = min_coord - step * np.arange(n_steps, 0, -1)
        else:
            new_coords = np.array([])
        new_coords_dict[dim] = new_coords
    return new_coords_dict
def extend_coords(
    da: xr.DataArray,
    new_coords: Mapping[Hashable, list[float] | NDArray[np.floating]],
) -> xr.DataArray:
    """Expand the coordinates of an xarray DataArray by adding new coordinate values.

    The result is always float64. Positions that did not exist in ``da`` are filled with NaN,
    or with 0 when ``da`` has an integer dtype.

    Args:
        da (xr.DataArray): The original DataArray.
        new_coords (Mapping[Hashable, list[float] | NDArray[np.floating]]): Coordinate values to
            add per dimension. They are merged with the existing values via ``np.union1d``
            (sorted, duplicates removed). Dimensions not listed keep their existing coordinates.

    Returns:
        xr.DataArray: A new DataArray with the expanded coordinates, holding the original
        values at their original coordinate labels and padding everywhere else.
    """
    stretch_coords = {dim: da.coords[dim].values for dim in da.dims}
    for dim, values in new_coords.items():
        stretch_coords[dim] = np.union1d(stretch_coords.get(dim, []), values)
    shape = [len(stretch_coords[dim]) for dim in da.dims]
    coords = da.coords.copy()
    coords.update(stretch_coords)
    # Treat every integer dtype alike; comparing against ``np.int_`` only matched the platform
    # default integer type, so e.g. int32/uint16 data was padded with NaN instead of 0.
    padding_value = 0 if np.issubdtype(da.dtype, np.integer) else np.nan
    expanded_da = xr.DataArray(
        np.full(shape, padding_value, dtype=np.float64),
        coords=coords,
        dims=list(da.dims),
        attrs=da.attrs,
    )
    # Write the original values back at their (unchanged) coordinate labels.
    expanded_da.loc[{dim: da.coords[dim] for dim in da.dims}] = da.astype(np.float64)
    return expanded_da
def is_equally_spaced(
    coords: xr.DataArray | NDArray[np.floating],
    dim_name: Hashable | None = None,
    **kwargs: Incomplete,
) -> float:
    """Return the spacing of ``coords``, warning when it is not uniform.

    The spacing counts as uniform when every consecutive difference is close
    (``np.allclose``) to the first one; that first difference is returned. Otherwise a
    ``UserWarning`` is emitted and the most frequently occurring difference is returned.

    Args:
        coords (xr.DataArray | NDArray[np.floating]): Coordinate values to check.
        dim_name (Hashable | None): Name shown in the warning message.
        **kwargs: Forwarded to ``np.allclose`` (atol, rtol, equal_nan, ...).

    Returns:
        float: The first difference if uniform, else the most common difference.
    """
    steps = np.diff(coords)
    first_step = steps[0]
    if not np.allclose(steps, first_step, **kwargs):
        [(interval, _count)] = Counter(steps).most_common(1)
        warnings.warn(
            f"Coordinate {dim_name} is not perfectly equally spaced. "
            f"Use the most common interval {interval}.",
            UserWarning,
            stacklevel=2,
        )
        return interval
    return first_step
def shift_by(
    data: xr.DataArray,
    coord_name: str,
    shift_value: float,
) -> xr.DataArray:
    """Shift one coordinate of ``data`` by a constant amount.

    The shift is also recorded as a ``(coord_name, shift_value)`` tuple appended to
    ``attrs["provenance"]["shift_coords"]`` of the returned DataArray.

    Args:
        data (xr.DataArray): The DataArray to shift.
        coord_name (str): The coordinate name to shift.
        shift_value (float): The amount of the shift (added to the coordinate values).

    Returns:
        xr.DataArray: A new DataArray with the shifted coordinate. ``data`` itself, including
        its provenance record, is left unmodified.
    """
    assert isinstance(data, xr.DataArray)
    assert coord_name in data.coords
    # Mapping form works for any hashable coordinate name, not only valid identifiers.
    shifted_data = data.assign_coords({coord_name: data.coords[coord_name] + shift_value})
    # assign_coords only shallow-copies attrs, so the nested provenance dict (and its
    # "shift_coords" list) would still be shared with ``data``; copy before extending.
    provenance_: Provenance = dict(shifted_data.attrs.get("provenance", {}))
    provenance_["shift_coords"] = [
        *provenance_.get("shift_coords", []),
        (coord_name, shift_value),
    ]
    shifted_data.attrs["provenance"] = provenance_
    return shifted_data
def corrected_coords(
    data: xr.DataArray,
    correction_types: CoordsOffset | Sequence[CoordsOffset],
) -> xr.DataArray:
    """Correct the coordinates of the given data by applying the stored offsets.

    Correction types are applied in order:

    * ``"<name>_offset"``: coordinate ``<name>`` is shifted by ``attrs["<name>_offset"]``
      (by its negative when ``<name>`` is a dimension), and ``attrs["<name>"]``, if present,
      is reduced by the offset. If ``<name>`` is not a coordinate, a warning is issued and
      only the coordinate shift is skipped.
    * ``"beta"`` / ``"theta"``: the offset is applied via ``_apply_beta_theta_offset``.

    After each correction ``attrs[correction_type]`` is set to 0 and the correction type is
    appended to ``attrs["provenance"]["coords_correction"]``.

    Args:
        data (xr.DataArray): The input ARPES data array with coordinates to be corrected. It is
            deep-copied and not modified.
        correction_types (CoordsOffset | Sequence[CoordsOffset]): A correction type, or a
            sequence of correction types, to apply.

    Returns:
        xr.DataArray: The data array with corrected coordinates.

    Raises:
        KeyError: If an ``"<name>_offset"`` correction is requested but that attribute is
            missing from ``data.attrs``.
    """
    if isinstance(correction_types, str):
        correction_types = (correction_types,)
    corrected_data = data.copy(deep=True)
    for correction_type in correction_types:
        assert correction_type in get_args(CoordsOffset)
        if "_offset" in correction_type:
            coord_name: LiteralString = correction_type.split("_offset")[0]
            has_coord = coord_name in corrected_data.coords
            if not has_coord:
                warnings.warn(
                    f"{coord_name} has not been set, while you correct "
                    f"{coord_name} by {correction_type}.",
                    stacklevel=2,
                )
            offset = corrected_data.attrs[correction_type]
            # shift_by requires the coordinate to exist, so the shift is only done when it
            # does; the attribute bookkeeping below runs either way.
            if has_coord:
                shift_value = -offset if coord_name in corrected_data.dims else offset
                corrected_data = shift_by(corrected_data, coord_name, shift_value)
            if coord_name in corrected_data.attrs:
                corrected_data.attrs[coord_name] -= offset
        # angle correction by beta or theta
        elif correction_type in {"beta", "theta"}:
            corrected_data = _apply_beta_theta_offset(corrected_data, correction_type)
        corrected_data.attrs[correction_type] = 0
        # provenance
        provenance_: Provenance = corrected_data.attrs.get("provenance", {})
        provenance_corrected_cords: list[CoordsOffset] = provenance_.get("coords_correction", [])
        provenance_corrected_cords.append(correction_type)
        provenance_["coords_correction"] = provenance_corrected_cords
        corrected_data.attrs["provenance"] = provenance_
    return corrected_data
def _apply_beta_theta_offset(
    data: xr.DataArray,
    correction_type: str,
) -> xr.DataArray:
    """Apply a ``beta`` or ``theta`` offset to the matching angular axis.

    ``beta`` shifts ``phi`` when ``data.S.is_slit_vertical`` is truthy and ``psi`` otherwise;
    ``theta`` does the opposite. The shift amount is ``attrs[correction_type]`` (0 if absent).
    In the returned DataArray both that attribute and a scalar coordinate named
    ``correction_type`` are set to 0.
    """
    assert correction_type in {"beta", "theta"}
    slit_is_vertical = data.S.is_slit_vertical
    if correction_type == "beta":
        target_axis = "phi" if slit_is_vertical else "psi"
    else:
        target_axis = "psi" if slit_is_vertical else "phi"
    offset = data.attrs.get(correction_type, 0)
    result = shift_by(data, target_axis, offset)
    result.attrs[correction_type] = 0
    result.coords[correction_type] = 0
    return result