Source code for arpes.analysis.deconvolution

"""Provides deconvolution implementations, especially for 2D Richardson-Lucy."""

from __future__ import annotations

from logging import DEBUG, INFO
from typing import TYPE_CHECKING

import numpy as np
import scipy
import scipy.ndimage
import xarray as xr
from lmfit.lineshapes import gaussian
from scipy.stats import multivariate_normal
from skimage.restoration import richardson_lucy

from arpes.debug import setup_logger
from arpes.provenance import update_provenance
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
    from collections.abc import Hashable

    from numpy.typing import NDArray


# Names exported by ``from arpes.analysis.deconvolution import *``.
__all__ = (
    "deconvolve_ice",
    "deconvolve_rl",
    "make_psf",
    "make_psf1d",
)

# Logging levels selectable for this module; index 1 (INFO) is the one in effect.
LOGLEVELS = (DEBUG, INFO)
LOGLEVEL = LOGLEVELS[1]
# Module-level logger configured through the project's ``setup_logger`` helper.
logger = setup_logger(__name__, LOGLEVEL)


@update_provenance("Approximate Iterative Deconvolution")
def deconvolve_ice(
    data: xr.DataArray,
    psf: NDArray[np.floating],
    n_iterations: int = 5,
    deg: int | None = None,
) -> xr.DataArray:
    """Deconvolve data by a point spread function using iterative convolution extrapolation.

    The measured data is treated as "the true signal convolved once" (step 1).
    It is convolved again with the PSF ``n_iterations - 1`` times (steps 2..n),
    then, independently for every data point, a polynomial in the step number is
    fitted and evaluated at step 0, i.e. extrapolated back to "zero convolutions".

    Works for data of any dimensionality as long as ``psf`` has the same number
    of dimensions as ``data`` (a requirement of ``scipy.ndimage.convolve``).

    Args:
        data (xr.DataArray): input data.
        psf: array used as the point spread function (convolution kernel).
        n_iterations (int): the number of convolution steps used for the fit
            (including the original data as step 1). Must be at least 1.
        deg (int | None): the degree of the fitting polynomial. Defaults to
            ``n_iterations - 3``.

    Returns:
        The deconvolved data in the same format as the input. Floating point
        input keeps its dtype; any other dtype is returned as float64.

    Raises:
        ValueError: if ``n_iterations`` is smaller than 1 or the (possibly
            defaulted) polynomial degree is negative.
    """
    data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
    if n_iterations < 1:
        msg = f"n_iterations must be at least 1, got {n_iterations}."
        raise ValueError(msg)
    if deg is None:
        deg = n_iterations - 3
    if deg < 0:
        msg = (
            f"The polynomial degree must be non-negative, got {deg} "
            f"(n_iterations={n_iterations})."
        )
        raise ValueError(msg)

    raw_values = np.asarray(data.values)
    # Do the arithmetic in float64: scipy.ndimage.convolve returns the dtype of
    # its input, so integer data would be truncated at every convolution step.
    arr: NDArray[np.floating] = raw_values.astype(np.float64)
    kernel = np.asarray(psf, dtype=np.float64)

    convolved = [arr]
    for _ in range(n_iterations - 1):
        convolved.append(
            np.asarray(scipy.ndimage.convolve(convolved[-1], kernel), dtype=np.float64),
        )

    # One column per data point so a single polyfit call fits all points at once,
    # whatever the dimensionality of the data.
    iteration_steps = np.arange(1, n_iterations + 1)
    series_per_point = np.stack(convolved).reshape(n_iterations, -1)
    coefs = np.polyfit(iteration_steps, series_per_point, deg=deg)

    # The value of a polynomial at 0 is its constant term, which polyfit returns last.
    deconv = coefs[-1].reshape(arr.shape)
    if np.issubdtype(raw_values.dtype, np.floating):
        deconv = deconv.astype(raw_values.dtype, copy=False)
    return data.G.with_values(deconv, keep_attrs=True)
@update_provenance("Lucy Richardson Deconvolution")
def deconvolve_rl(
    data: xr.DataArray,
    psf: xr.DataArray,
    n_iterations: int = 10,
) -> xr.DataArray:
    """Deconvolve data by a given point spread function using the Richardson-Lucy (RL) method.

    Args:
        data: input data.
        psf: The point spread function. An ``xr.DataArray`` (e.g. the output of
            ``make_psf``) or anything ``np.asarray`` can convert to an array.
        n_iterations: the number of Richardson-Lucy iterations to run.

    Returns:
        The Richardson-Lucy deconvolved data, with the coordinates and
        attributes of the input.
    """
    data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
    im_deconv = richardson_lucy(
        image=data.values,
        psf=np.asarray(psf),
        num_iter=n_iterations,
        # scikit-image clips the result to [-1, 1] by default for its image
        # pipeline; that would saturate spectra whose intensities exceed 1.
        clip=False,
        filter_epsilon=None,
    )
    return data.G.with_values(im_deconv, keep_attrs=True)
@update_provenance("Make 1D-Point Spread Function")
def make_psf1d(
    data: xr.DataArray,
    dim: str,
    sigma: float,
) -> xr.DataArray:
    """Build a one-dimensional gaussian point spread function along ``dim``.

    The result is intended for use with ``deconvolve_rl``.

    Args:
        data (xr.DataArray): data providing the coordinates along ``dim``.
        dim (str): name of the dimension along which the gaussian is defined.
        sigma (float): sigma of the gaussian, passed to lmfit's ``gaussian``.

    Returns:
        A one dimensional point spread array: a gaussian centered on the mean
        of the ``dim`` coordinate.

    Raises:
        ValueError: if ``dim`` is not one of the dimensions of ``data``.
    """
    if not isinstance(data, xr.DataArray):
        data = normalize_to_spectrum(data)

    # Array of ones with the layout of ``data``; only its ``dim`` axis survives below.
    template = data.G.with_values(np.zeros_like(data.values) + 1)

    collapsed_dims = list(data.dims)
    collapsed_dims.remove(dim)
    # Keep the first element along every dimension other than ``dim``.
    template = template.isel({name: 0 for name in collapsed_dims})

    axis = template.coords[dim]
    profile = gaussian(x=axis, center=np.mean(axis), sigma=sigma)
    return template * profile
@update_provenance("Make Point Spread Function")
def make_psf(
    data: xr.DataArray,
    sigmas: dict[Hashable, float],
    *,
    fwhm: bool = True,
    clip: float | None = None,
) -> xr.DataArray:
    """Produce an n-dimensional gaussian point spread function for use in deconvolve_rl.

    The PSF is sampled on a grid centered at zero that uses the same step size
    as ``data`` along every dimension. Each axis has an odd number of points
    (the size of ``data`` along that axis, minus one when that size is even),
    so the center of the gaussian falls exactly on a grid point.

    Args:
        data (xr.DataArray): input data; supplies dimension names, sizes and strides.
        sigmas (dict[Hashable, float]): width of the gaussian for each dimension.
            Its keys must match the dimensions reported by ``data.G.stride``.
        fwhm (bool): if True, the values in ``sigmas`` are FWHM, not the standard deviation.
        clip (float | None): if given (and non-zero), restrict the returned PSF
            to ``±clip`` standard deviations along every dimension.

    Returns:
        The PSF to use, as a DataArray named "PSF" with the dimensions of ``data``.
    """
    strides = data.G.stride(generic_dim_names=False)
    logger.debug("strides: %s", strides)
    assert set(strides) == set(sigmas), "sigmas must have exactly one entry per dimension"

    # Force an odd number of samples per axis so that 0 is one of the grid points.
    pixels: dict[Hashable, int] = {
        dim: size - 1 if size % 2 == 0 else size
        for dim, size in zip(data.dims, data.shape, strict=True)
    }

    if fwhm:
        fwhm_per_sigma = 2 * np.sqrt(2 * np.log(2))
        sigmas = {k: v / fwhm_per_sigma for k, v in sigmas.items()}

    # sigma is a standard deviation, but multivariate_normal expects (co)variances.
    cov: NDArray[np.floating] = np.zeros((len(sigmas), len(sigmas)), dtype=np.float64)
    for i, dim in enumerate(data.dims):
        cov[i][i] = sigmas[dim] ** 2
    logger.debug("cov: %s", cov)

    # Index every per-dimension mapping by the dimension name itself (not its
    # str() form), so dimension names that are not strings work as well.
    psf_coords: dict[Hashable, NDArray[np.floating]] = {}
    for dim in data.dims:
        half_extent = (pixels[dim] - 1) / 2 * strides[dim]
        psf_coords[dim] = np.linspace(-half_extent, half_extent, pixels[dim])
    for k, v in psf_coords.items():
        logger.debug(
            "psf_coords[%s]: ±%.3f",
            k,
            np.max(v),
        )

    grids = np.meshgrid(*[psf_coords[dim] for dim in data.dims], indexing="ij")
    # Positions at which the probability density function (pdf) is evaluated;
    # the last axis enumerates the dimensions.
    coords_for_pdf_pos = np.stack(grids, axis=-1)
    logger.debug("shape of coords_for_pdf_pos: %s", coords_for_pdf_pos.shape)

    psf = xr.DataArray(
        data=multivariate_normal(mean=np.zeros(len(sigmas)), cov=cov).pdf(
            x=coords_for_pdf_pos,
        ),
        dims=data.dims,
        coords=psf_coords,
        name="PSF",
    )
    if clip:
        clipping_region = {k: slice(-clip * v, clip * v) for k, v in sigmas.items()}
        return psf.sel(indexers=clipping_region)
    return psf