Source code for arpes.plotting.bz

"""Utilities related to plotting Brillouin zones and data onto them."""

from __future__ import annotations

import warnings
from logging import DEBUG, INFO
from typing import TYPE_CHECKING, TypeAlias

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from ase.dft.bz import bz_plot, bz_vertices
from matplotlib.axes import Axes
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from scipy.spatial.transform import Rotation

from arpes.analysis.mask import apply_mask_to_coords
from arpes.constants import TWO_DIMENSION
from arpes.debug import setup_logger
from arpes.utilities.bz import build_2dbz_poly, process_kpath
from arpes.utilities.geometry import polyhedron_intersect_plane

from .utils import path_for_plot

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from pathlib import Path

    from _typeshed import Incomplete
    from ase.cell import Cell
    from matplotlib.figure import Figure
    from matplotlib.typing import ColorType
    from mpl_toolkits.mplot3d import Axes3D
    from numpy.typing import NDArray


# Names exported by ``from arpes.plotting.bz import *``.
__all__ = (
    "bz2d_segments",
    "overplot_standard",
    "plot_data_to_bz",
    "plot_data_to_bz2d",
)

# Toggle for this module's log level: index 1 selects INFO, index 0 selects DEBUG.
LOGLEVEL = (DEBUG, INFO)[1]
logger = setup_logger(__name__, LOGLEVEL)


class Translation:
    """Base translation class, meant to provide some extension over rotations.

    Rotations are available from `scipy.spatial.transform.Rotation`. Like a
    `Rotation`, a `Translation` exposes an ``apply`` method, so both kinds can be
    mixed in the transformation lists used by this module.

    Attributes:
        dim: Number of components of the translation vector (2 or 3 expected).
        translation_vector: Offset added to every vector by ``apply``.
    """

    def __init__(self, translation_vector: Sequence[float]) -> None:
        """Store the translation vector.

        Args:
            translation_vector: Offset to add, e.g. ``(dx, dy)`` or ``(dx, dy, dz)``.
        """
        self.dim = len(translation_vector)
        self.translation_vector: NDArray[np.float64] = np.asarray(translation_vector)

    def apply(self, vectors: NDArray[np.float64]) -> NDArray[np.float64]:
        """Applies the translation to a set of vectors.

        The translation vector is added to every input vector (broadcast over the
        leading axis for an ``(N, D)`` input). Unlike a rotation, a translation is
        not an involution: applying it twice shifts by twice the vector.

        Args:
            vectors: array_like with shape (2 or 3,) or (N, 2 or 3)

        Returns:
            The translated vectors, with the same shape as ``vectors``.

        Raises:
            ValueError: If ``vectors`` is not a single 2D/3D vector or a stack of
                them, or if its dimension differs from this translation's ``dim``.
        """
        vectors = np.asarray(vectors)

        # Accept exactly one vector (ndim 1) or a stack of vectors (ndim 2). The ndim
        # test short-circuits first so a 0-d input never reaches ``shape[-1]``.
        if vectors.ndim not in {1, 2} or vectors.shape[-1] not in {2, 3}:
            msg = "Expected a 2D or 3D vector (2 or 3,)"
            msg += f" or list of vectors (N, 2 or 3,), instead received: {vectors.shape}"
            raise ValueError(msg)

        # Report a D-dimensional translation applied to different-dimensional vectors
        # explicitly instead of relying on numpy's broadcasting error.
        if vectors.shape[-1] != self.dim:
            msg = (
                f"Dimension mismatch: translation is {self.dim}D "
                f"but vectors are {vectors.shape[-1]}D"
            )
            raise ValueError(msg)

        return vectors + self.translation_vector


# Anything with an ``apply(vectors)`` method that ``apply_transformations`` can chain:
# a scipy ``Rotation`` or a ``Translation``.
Transformation: TypeAlias = Rotation | Translation


def segments_standard(
    cell: Cell,
    transformations: list | None = None,
) -> tuple[list[NDArray[np.float64]], list[NDArray[np.float64]]]:
    """Return the x and y line segments outlining a 2D Brillouin zone.

    Convenience alias that delegates to ``bz2d_segments``.

    Args:
        cell: ASE Cell object (real space); ``bz2d_segments`` requires rank 2.
        transformations: Optional transformations forwarded unchanged.

    Returns:
        The ``(segments_x, segments_y)`` pair produced by ``bz2d_segments``.
    """
    segments = bz2d_segments(cell, transformations=transformations)
    return segments


def overplot_standard(
    cell: Cell,
    repeat: tuple[int, int, int] | tuple[int, int] = (1, 1, 1),
    transforms: list | None = None,
) -> Callable[[Axes], Axes]:
    """A higher order function to plot a Brillouin zone over a plot.

    Args:
        cell (Cell): ASE Cell object for BZ drawing.
        repeat: Repetition of the BZ drawing, forwarded to ``bz_plot``.
            Default is (1, 1, 1), i.e. no repetition.
        transforms: List of linear transformations
            (``scipy.spatial.transform.Rotation``) forwarded to ``bz_plot``.
            Defaults to a single identity rotation.

    Returns:
        A callable taking an ``Axes``. It draws the Brillouin zone with ``bz_plot``
        (``zorder=5``, no paths), turns the axis on, and returns the ``Axes``
        that ``bz_plot`` returned.
    """
    if transforms is None:
        # A zero rotation vector is the identity rotation.
        transforms = [Rotation.from_rotvec([0, 0, 0])]
    # Lazy %-formatting: the list is only rendered when DEBUG logging is enabled.
    logger.debug("transforms: %s", transforms)

    def overplot_the_bz(ax: Axes) -> Axes:
        """Draw the configured Brillouin zone onto ``ax`` and return the axes."""
        ax = bz_plot(
            cell=cell,
            ax=ax,
            paths=[],
            repeat=repeat,
            transforms=transforms,
            # Presumably forwarded to matplotlib so the outline sits above the data.
            zorder=5,
        )
        ax.set_axis_on()
        return ax

    return overplot_the_bz


def apply_transformations(
    points: NDArray[np.float64],
    transformations: list[Transformation] | None = None,
) -> NDArray[np.float64]:
    """Run ``points`` through each transformation in turn.

    Every element only needs an ``apply`` method (e.g. a scipy ``Rotation`` or a
    ``Translation``). The output of one step is the input of the next.

    Args:
        points: A single point or a stack of points.
        transformations: Transformations applied in list order. ``None`` means
            nothing is applied.

    Returns:
        The transformed points, or ``points`` itself when there is nothing to apply.
    """
    if transformations is None:
        return points
    current = points
    for step in transformations:
        current = step.apply(current)
    return current


def plot_plane_to_bz(
    cell: Cell,
    plane: str | list[NDArray[np.float64]],
    ax: Axes3D,
    facecolor: ColorType = "red",
) -> None:
    """Plots a 2D cut plane onto a Brillouin zone.

    Deprecated: emits a ``DeprecationWarning`` on every call.

    The plane is defined by three points ``p0, p1, p2``. Its normal is
    ``(p1 - p0) x (p2 - p0)``. The intersection of that plane with the
    Brillouin-zone polyhedron is drawn as a filled polygon.

    Args:
        cell (Cell): ASE cell object
        plane: Either a k-path string, in which case the first element returned by
            ``process_kpath(plane, cell)`` is used as the plane points, or an
            explicit sequence of at least three points lying on the plane.
        ax: 3D axes into which the intersection polygon is added.
        facecolor: Fill color of the intersection polygon.
    """
    warnings.warn(
        "This method will be deprecated.",
        category=DeprecationWarning,
        stacklevel=2,
    )

    if isinstance(plane, str):
        plane_points: list[NDArray[np.float64]] = process_kpath(
            plane,
            cell,
        )[0]
    else:
        plane_points = plane

    # Two in-plane direction vectors; their cross product is the plane normal.
    d1, d2 = plane_points[1] - plane_points[0], plane_points[2] - plane_points[0]

    # inv(cell).T has the reciprocal lattice vectors as rows (a_i . b_j = delta_ij).
    # bz_vertices yields (vertices, normal) pairs; only the vertex arrays are needed.
    faces = [p[0] for p in bz_vertices(np.linalg.inv(cell).T)]
    pts = polyhedron_intersect_plane(faces, np.cross(d1, d2), plane_points[0])

    collection = Poly3DCollection([pts])
    collection.set_facecolor(facecolor)
    ax.add_collection3d(collection, zs=0, zdir="z")


def plot_data_to_bz(
    data: xr.DataArray,
    cell: Cell,
    **kwargs: Incomplete,
) -> Path | tuple[Figure | None, Axes]:
    """Plot ARPES data onto a Brillouin zone, dispatching on data dimensionality.

    Only the 2D case is implemented; it is forwarded to ``plot_data_to_bz2d``.

    Args:
        data: k-space converted ARPES data.
        cell: ASE Cell object (real space).
        kwargs: Passed through to ``plot_data_to_bz2d``.

    Returns:
        Whatever ``plot_data_to_bz2d`` returns.

    Raises:
        NotImplementedError: If ``data`` is three-dimensional.
    """
    # Use ndim (the number of dimensions). len() of a DataArray is the length of
    # its first axis, which would reject 2D data whose first axis has 3 points.
    if data.ndim == TWO_DIMENSION + 1:
        raise NotImplementedError
    return plot_data_to_bz2d(data, cell, **kwargs)


def plot_data_to_bz2d(  # noqa: PLR0913
    data_array: xr.DataArray,
    cell: Cell,
    rotate: float | None = None,
    shift: NDArray[np.float64] | None = None,
    scale: float | None = None,
    ax: Axes | None = None,
    out: str | Path = "",
    bz_number: Sequence[float] | None = None,
    *,
    mask: bool = True,
    **kwargs: Incomplete,
) -> Path | tuple[Figure | None, Axes]:
    """Plots data onto the 2D Brillouin zone.

    The coordinate meshgrid of ``data_array`` is optionally rotated, then scaled,
    then shifted. The data is optionally masked using the 2D BZ polygon. It is
    then drawn with ``pcolormesh``, translated into the zone selected by
    ``bz_number``.

    Args:
        data_array: k-space converted data to plot (must be a DataArray).
        cell(Cell): ASE Cell object (Real space)
        rotate: Angle in radians used to build a 2x2 rotation matrix, which is
            applied to the coordinate meshgrid.
        shift: Offset applied to the coordinate meshgrid (after rotate/scale).
        scale: Scale factor applied to the coordinate meshgrid (after rotate).
        ax (Axes): Axes to draw into. If omitted, a new 9x9 figure is created and
            ``bz_plot(cell, paths="all")`` is drawn into it first.
        out: If non-empty, the current figure is saved at dpi=400 to
            ``path_for_plot(out)`` and that path is returned.
        bz_number: Coefficients ``(n1, n2)``. The plot is translated by
            ``n1 * b1 + n2 * b2``, using the first two rows of ``cell.reciprocal()``.
            Defaults to ``(0, 0)``.
        mask: If True, points selected by the mask built from the 2D BZ polygon
            are set to NaN, which renders them transparent.
        kwargs: Only ``cmap`` (a Colormap or a registered colormap name, default
            ``"Blues"``) is used; other keys are currently ignored.

    Returns:
        The saved path if ``out`` is given. Otherwise ``(fig, ax)``, where ``fig``
        is ``None`` when the caller supplied ``ax``.
    """
    # Type check first, so a Dataset fails with the explicit message below.
    assert isinstance(data_array, xr.DataArray), "data_array must be xr.DataArray, not Dataset"
    assert data_array.S.is_kspace, "You must k-space convert data before plotting to BZs"
    if bz_number is None:
        bz_number = (0, 0)

    fig = None
    if ax is None:
        fig, ax = plt.subplots(figsize=(9, 9))
        bz_plot(cell, paths="all", ax=ax)
    assert isinstance(ax, Axes)

    icell = cell.reciprocal()

    # Prep coordinates and mask
    raveled = data_array.G.meshgrid(as_dataset=True)
    dims = data_array.dims
    if rotate is not None:
        c, s = np.cos(rotate), np.sin(rotate)
        rotation = np.array([(c, -s), (s, c)])
        raveled = raveled.G.transform_meshgrid(dims, rotation)
    if scale is not None:
        raveled = raveled.G.scale_meshgrid(dims, scale)
    if shift is not None:
        raveled = raveled.G.shift_meshgrid(dims, shift)

    copied = data_array.values.copy()
    # Integer arrays cannot hold the NaN used for masking; promote them to float.
    if not np.issubdtype(copied.dtype, np.inexact):
        copied = copied.astype(np.float64)
    if mask:
        built_mask = apply_mask_to_coords(raveled, build_2dbz_poly(cell=cell), dims)
        copied[built_mask.T] = np.nan

    cmap = kwargs.get("cmap")
    if cmap is None:
        cmap = mpl.colormaps["Blues"]
    elif isinstance(cmap, str):
        cmap = mpl.colormaps.get_cmap(cmap)
    # with_extremes returns a copy, so a caller-supplied Colormap is not mutated.
    # Fully transparent "bad" color makes masked (NaN) cells invisible.
    cmap = cmap.with_extremes(bad=(1, 1, 1, 0))

    # Offset into the requested zone: n1 * b1 + n2 * b2 (x and y components).
    delta_x, delta_y = np.asarray(bz_number) @ np.asarray(icell)[:2, :2]

    ax.pcolormesh(
        raveled.data_vars[dims[0]].values + delta_x,
        raveled.data_vars[dims[1]].values + delta_y,
        copied.T,
        cmap=cmap,
    )

    if out:
        # Resolve the output path once, so the saved and returned paths agree.
        out_path = path_for_plot(out)
        # NOTE(review): plt.savefig saves the *current* figure. If the caller passed
        # an ``ax`` from a non-current figure, a different figure is saved. Confirm.
        plt.savefig(out_path, dpi=400)
        return out_path

    return fig, ax


def bz2d_segments(
    cell: Cell,
    transformations: list[Transformation] | None = None,
) -> tuple[list[NDArray[np.float64]], list[NDArray[np.float64]]]:
    """Calculates the line segments corresponding to a 2D BZ.

    Args:
        cell: Rank-2 ASE Cell object (real space).
        transformations: Optional transformations applied, in order, to each
            vertex group before it is split into coordinates.

    Returns:
        ``(segments_x, segments_y)``. For each vertex group returned by
        ``bz_vertices``, these hold the x and y coordinates of its vertices, with
        the first vertex repeated at the end so the outline is closed.
    """
    segments_x = []
    segments_y = []
    assert cell.rank == TWO_DIMENSION
    for points, _ in twocell_to_bz1(cell)[0]:
        transformed_points = apply_transformations(points, transformations)
        # Close the polyline by appending the first vertex; drop the third coordinate.
        x, y, _ = np.concatenate([transformed_points, transformed_points[:1]]).T
        segments_x.append(x)
        segments_y.append(y)
    return segments_x, segments_y


def twocell_to_bz1(
    cell: Cell,
) -> tuple[list[tuple[NDArray[np.float64], NDArray[np.float64]]], Cell, Cell]:
    """Compute the first-Brillouin-zone vertex data for a cell.

    Args:
        cell: ASE Cell object (real space).

    Returns:
        ``(bz1, icell, cell)``: the ``bz_vertices`` output for the reciprocal cell
        (computed with ``dim=cell.rank``), the reciprocal cell, and the input cell.
    """
    icell = cell.reciprocal()
    bz1 = bz_vertices(icell, dim=cell.rank)
    return bz1, icell, cell