"""Interactive smoothing and differentiation applications for xarray DataArray using Panel and HoloViews.
This module defines a `SmoothingApp` class, which provides a user interface for
applying smoothing filters (Gaussian, Savitzky-Golay, or boxcar) to 1D or 2D xarray DataArrays,
and a `DifferentiateApp` class, which additionally applies a derivative-based procedure
(derivative, maximum curvature, or minimum gradient) to the smoothed data.
Users can interactively control which axes to smooth and the filter parameters,
and visualize the results.
Dependencies:
- panel
- holoviews
- xarray
- arpes.analysis: gaussian_filter_arr, savitzky_golay_filter, boxcar_filter_arr, and related routines
"""
from __future__ import annotations
from logging import DEBUG, INFO
from typing import TYPE_CHECKING, Any, Unpack, cast
import holoviews as hv
import panel as pn
from holoviews.operation.datashader import regrid
from holoviews.streams import PointerX, PointerY
import arpes.xarray_extensions # pyright: ignore[reportUnusedImport] # noqa: F401
from arpes.analysis import (
boxcar_filter_arr,
curvature1d,
curvature2d,
dn_along_axis,
gaussian_filter_arr,
minimum_gradient,
savgol_filter_multi,
savitzky_golay_filter,
)
from arpes.constants import TWO_DIMENSION
from arpes.debug import setup_logger
from arpes.preparation import normalize_max
from ._helper import fix_xarray_to_fit_with_holoview, get_image_options, get_plot_lim
from .base import BaseUI, image_with_pointer, profile_curve
if TYPE_CHECKING:
from collections.abc import Callable, Hashable
import xarray as xr
from param.parameterized import Event
from arpes._typing.plotting import ProfileViewParam
# Log levels this module can be switched between: index 0 is DEBUG, index 1 is INFO.
LOGLEVELS = (DEBUG, INFO)
# Currently the first entry, i.e. DEBUG.
LOGLEVEL = LOGLEVELS[0]
# Module-level logger obtained from the project's ``setup_logger`` helper.
logger = setup_logger(__name__, LOGLEVEL)
# Import-time side effects: load the Bokeh backend for HoloViews (logo suppressed)
# and initialise Panel.
hv.extension("bokeh", logo=False)
pn.extension()
[docs]
class SmoothingApp(BaseUI):
"""An interactive smoothing UI for xarray DataArray using Panel and HoloViews."""
[docs]
def __init__(self, data: xr.DataArray, **kwargs: Unpack[ProfileViewParam]) -> None:
    """Create the pointer streams and build the smoothing UI.

    Args:
        data (xr.DataArray): Input data to be smoothed.
        **kwargs: Extra UI options, forwarded unchanged to the base class initialiser.
    """
    super().__init__(data, **kwargs)

    # Start the pointer streams at the coordinates reported by ``data.G.argmax_coords()``.
    start_coords = data.G.argmax_coords()
    first_dim = data.dims[0]
    self.posx = PointerX(x=start_coords[first_dim])
    if data.ndim == TWO_DIMENSION:
        second_dim = data.dims[1]
        self.posy = PointerY(y=start_coords[second_dim])

    self._build()
def _build(self) -> None:
    """Create the widgets, wire their callbacks, and assemble ``self.layout``."""
    # Height and width are always overwritten; the other two only receive defaults.
    self.pane_kwargs["height"] = 400
    self.pane_kwargs["width"] = 450
    self.pane_kwargs.setdefault("colorbar", True)
    self.pane_kwargs.setdefault("profile_view_height", 130)

    # Registry: display name -> (smoothing callable, its parameter widgets keyed by kwarg name).
    registry: dict[
        str,
        tuple[
            Callable[..., xr.DataArray],
            dict[Hashable, pn.widgets.Widget],
        ],
    ] = {
        "None": (lambda x: x, {}),
        "Gaussian": (self._gaussian_smoothing, _gaussian_slider(self.data)),
        "Savitzky-Golay": (self._savitzky_golay_smoothing, _savgol_slider(self.data)),
        "Boxcar": (self._boxcar_smoothing, _boxcar_slider(self.data)),
    }
    self.smoothing_funcs = registry

    self.smoothing_select = pn.widgets.Select(
        name="Smoothing Function",
        options=list(self.smoothing_funcs),
    )
    self.output_name = pn.widgets.TextInput(
        name="Output Name",
        placeholder="e.g., smoothed1",
    )

    # Draw the initial plot before the remaining widgets are created.
    self._update_plot()

    self.output_button = pn.widgets.Button(name="Apply", button_type="primary")
    self.output_button.on_click(self._on_apply)

    # Container whose content follows the currently selected smoothing function.
    self.param_widgets_box = pn.Column()
    self._update_smooth_param_widgets()
    self.smoothing_select.param.watch(self._update_smooth_param_widgets, "value")

    self.widgets_panel = pn.Column(
        self.smoothing_select,
        self.param_widgets_box,
        self.output_name,
        self.output_button,
    )
    side_column = pn.Column(
        self.widgets_panel,
        pn.layout.Divider(),
        self.message_pane,
    )
    self.layout = pn.Row(self.output_pane, side_column)
def _get_current_params(self) -> dict[str, float | int]:
"""Retrieve current values from parameter widgets.
Returns:
dict[str, float | int]: Parameter names and their current values.
"""
_, param_widgets = self.smoothing_funcs[str(self.smoothing_select.value)]
return {
str(name): widget.value
for name, widget in param_widgets.items()
if isinstance(widget.value, (float, int))
}
def _update_smooth_param_widgets(self, *_: Event) -> None:
"""Update the parameter widgets based on the selected smoothing function."""
__, param_widgets = self.smoothing_funcs[str(self.smoothing_select.value)]
self.param_widgets_box.objects = list(param_widgets.values())
def _on_apply(self, _: Event) -> None:
"""Callback when Apply button is clicked. Applies the selected filter."""
smooth_func, __ = self.smoothing_funcs[str(self.smoothing_select.value)]
kwargs = self._get_current_params()
self.output = smooth_func(self.data, **kwargs)
name = self.output_name.value
if name:
self.named_output[name] = self.output
self._update_plot()
def _update_plot(self) -> None:
    """Redraw the output pane from ``self.output``.

    1D data is drawn as a single curve. 2D data is drawn as an image with pointer
    streams, plus x/y profile curves of both the processed data and ``self.data``
    (the latter as thin black lines). Other dimensionalities leave the pane untouched.
    """
    processed = fix_xarray_to_fit_with_holoview(self.output)
    reference = fix_xarray_to_fit_with_holoview(self.data)
    opts = self.pane_kwargs

    if processed.ndim == 1:
        line = hv.Curve(processed, kdims=[processed.dims[0]])
        self.output_pane.object = line.opts(height=opts["height"])
        return
    if processed.ndim != TWO_DIMENSION:
        return

    # Fresh pointer streams, positioned from ``processed.G.argmax_coords()``.
    start_coords = processed.G.argmax_coords()
    self.posx = PointerX(x=start_coords[processed.dims[0]])
    self.posy = PointerY(y=start_coords[processed.dims[1]])

    image_options = get_image_options(
        log=opts["log"],
        cmap=opts["cmap"],
        width=opts["width"],
        height=opts["height"],
    )
    image_options["colorbar"] = opts["colorbar"]

    # Plot limits come from the unprocessed data and are shared by all four profiles.
    plot_lim = get_plot_lim(reference, log=opts["log"])
    img = image_with_pointer(
        data=processed,
        use_quadmesh=True,
        posx=self.posx,
        posy=self.posy,
        **image_options,
    )

    shared = {
        "plot_lim": plot_lim,
        "profile_size": opts["profile_view_height"],
        "log": opts["log"],
    }
    thin_black = {"line_color": "black", "line_width": 1}

    processed_x = profile_curve(data=processed, stream=self.posx, orientation="x", **shared)
    processed_y = profile_curve(data=processed, stream=self.posy, orientation="y", **shared)
    reference_x = profile_curve(
        data=reference,
        stream=self.posx,
        orientation="x",
        **thin_black,
        **shared,
    )
    reference_y = profile_curve(
        data=reference,
        stream=self.posy,
        orientation="y",
        **thin_black,
        **shared,
    )

    self.output_pane.object = (
        img << (reference_x * processed_x) << (reference_y * processed_y)
    )
def _gaussian_smoothing(self, data: xr.DataArray, **kwargs: float) -> xr.DataArray:
    """Apply ``gaussian_filter_arr`` to ``data``.

    Args:
        data (xr.DataArray): Data to be filtered.
        **kwargs: ``iteration`` (optional, default 1) is the iteration count; every
            other entry is passed on as the ``sigma`` mapping.

    Returns:
        xr.DataArray: Whatever ``gaussian_filter_arr`` returns.
    """
    n_iterations = int(kwargs.pop("iteration", 1))
    # Everything left after removing "iteration" forms the sigma mapping.
    per_dim_sigma = cast("dict[Hashable, float]", kwargs)
    return gaussian_filter_arr(arr=data, sigma=per_dim_sigma, iteration_n=n_iterations)
def _savitzky_golay_smoothing(self, data: xr.DataArray, **kwargs: Any) -> xr.DataArray:
axis_params: dict[str, tuple[int, int]] = {}
for k, v in kwargs.items():
param_name, axis_name = k.rsplit("_", 1)
if axis_name not in axis_params:
axis_params[axis_name] = (1, 0)
if param_name == "window_length":
axis_params[axis_name] = (int(v), axis_params[axis_name][1])
elif param_name == "polyorder":
axis_params[axis_name] = (axis_params[axis_name][0], int(v))
else:
msg = f"❌ Unknown parameter {param_name} in Savitzky-Golay smoothing.\n"
raise ValueError(msg)
for v in axis_params.values():
if v[0] % 2 == 0:
self.log_message("❌ Window length must be odd for Savitzky-Golay filter.\n")
return data
if v[0] < v[1]:
self.log_message("❌ Polyorder must be less than window_length.\n")
return data
return savgol_filter_multi(data, axis_params=axis_params)
def _boxcar_smoothing(self, data: xr.DataArray, **kwargs: float) -> xr.DataArray:
    """Apply ``boxcar_filter_arr`` to ``data``.

    Args:
        data (xr.DataArray): Data to be filtered.
        **kwargs: ``iteration`` (optional, default 1) is the iteration count; every
            other entry is passed on as the ``size`` mapping.

    Returns:
        xr.DataArray: Whatever ``boxcar_filter_arr`` returns.
    """
    n_iterations = kwargs.pop("iteration", 1)
    # Everything left after removing "iteration" forms the kernel-size mapping.
    kernel_sizes = cast("dict[Hashable, float]", kwargs)
    return boxcar_filter_arr(arr=data, size=kernel_sizes, iteration_n=int(n_iterations))
[docs]
class DifferentiateApp(SmoothingApp):
"""An interactive differentiation UI for xarray DataArray using Panel and HoloViews.
After smoothing, a derivative, maximum curvature (1D or 2D), or minimum gradient technique
can be applied to find the peak position.
"""
[docs]
def __init__(self, data: xr.DataArray, **kwargs: Unpack[ProfileViewParam]) -> None:
    """Build the UI and remember the maximum value of the input data.

    Args:
        data (xr.DataArray): Input data to be differentiated.
        **kwargs: Extra UI options, forwarded unchanged to the parent initialiser.
    """
    super().__init__(data, **kwargs)
    # Stored for the derivative/curvature methods, which pass it to ``normalize_max``.
    largest = data.max()
    self.max_intensity = largest.item()
def _build(self) -> None:
    """Extend the smoothing UI with a derivative selector and its parameter widgets."""
    super()._build()

    # Registry: display name -> (derivative callable, its parameter widgets keyed by kwarg name).
    registry: dict[
        str,
        tuple[
            Callable[..., xr.DataArray],
            dict[Hashable, pn.widgets.Widget],
        ],
    ] = {
        "None": (lambda x: x, {}),
        "Derivative": (self._derivative, _derivative_slider(self.data)),
        "n-th Derivative by Savitzky-Golay filter": (
            self._n_th_derivative_with_SG,
            _savgol_deriv_slider(self.data),
        ),
        "Maximum curvature (1D)": (
            self._maximum_curvature_1d,
            _max_curvature_1d_slider(self.data),
        ),
        "Maximum curvature (2D)": (
            self._maximum_curvature_2d,
            _max_curvature_2d_slider(),
        ),
        "Minimum Gradient": (self._minimum_gradient, {}),
    }
    self.derivative_funcs = registry

    self.derivation_select = pn.widgets.Select(
        name="Derivative Function",
        options=list(self.derivative_funcs),
    )

    # Container whose content follows the currently selected derivative function.
    self.derivative_param_widgets_box = pn.Column()
    self._update_derivative_param_widgets()
    self.derivation_select.param.watch(self._update_derivative_param_widgets, "value")

    # Replace the panel and layout created by the parent so the derivative controls appear.
    self.widgets_panel = pn.Column(
        self.smoothing_select,
        self.param_widgets_box,
        self.output_name,
        self.derivation_select,
        self.derivative_param_widgets_box,
        self.output_button,
    )
    side_column = pn.Column(
        self.widgets_panel,
        pn.layout.Divider(),
        self.message_pane,
    )
    self.layout = pn.Row(self.output_pane, side_column)
def _update_derivative_param_widgets(self, *_: Event) -> None:
"""Update the parameter widgets based on the selected smoothing function."""
__, param_widgets = self.derivative_funcs[str(self.derivation_select.value)]
self.derivative_param_widgets_box.objects = list(param_widgets.values())
def _on_apply(self, _: Event) -> None:
"""Callback when Apply button is clicked.ArithmeticError.
Applies the selected filter and then selected derivative procedure.
"""
smooth_func, __ = self.smoothing_funcs[str(self.smoothing_select.value)]
kwargs = self._get_current_params()
self.output = smooth_func(self.data, **kwargs)
derivative_func, __ = self.derivative_funcs[str(self.derivation_select.value)]
derivative_kwargs = self._get_current_derivative_params()
self.output = derivative_func(self.output, **derivative_kwargs)
name = self.output_name.value
if name:
self.named_output[name] = self.output
self._update_plot()
def _update_plot0(self) -> None:
    """Draw ``self.output`` as a plain curve (1D) or a regridded image (2D).

    Unlike ``_update_plot`` this variant adds no pointer streams or profile curves.
    Data that is neither 1D nor 2D leaves the pane untouched.
    """
    current = self.output
    opts = self.pane_kwargs

    if current.ndim == 1:
        line = hv.Curve(current, kdims=[current.dims[0]])
        self.output_pane.object = line.opts(height=opts["height"])
        return
    if current.ndim != TWO_DIMENSION:
        return

    row_dim, col_dim = current.dims[0], current.dims[1]
    image_options = get_image_options(
        log=opts["log"],
        cmap=opts["cmap"],
        width=opts["width"],
        height=opts["height"],
    )
    # The second dimension goes on the x axis, the first on the y axis.
    image_options["xlabel"] = col_dim
    image_options["ylabel"] = row_dim
    image_options["colorbar"] = opts["colorbar"]

    image_source = (current.coords[col_dim], current.coords[row_dim], current.values)
    self.output_pane.object = regrid(hv.Image(image_source)).opts(**image_options)
def _get_current_derivative_params(self) -> dict[str, float | int | str]:
"""Retrieve current values from parameter widgets.
Returns:
dict[str, float | int | str]: Parameter names and their current values.
"""
_, param_widgets = self.derivative_funcs[str(self.derivation_select.value)]
return {
str(k): v.value
for k, v in param_widgets.items()
if isinstance(v.value, (float, int, str))
}
def _derivative(self, data: xr.DataArray, **kwargs: int) -> xr.DataArray:
    """Differentiate ``data`` with ``dn_along_axis``.

    Args:
        data (xr.DataArray): Data to be differentiated.
        **kwargs: ``axis`` (default: first dimension of ``data``) and
            ``derivative_order`` (default 1).

    Returns:
        xr.DataArray: Whatever ``dn_along_axis`` returns.
    """
    target_dim = kwargs.get("axis", data.dims[0])
    n_order = kwargs.get("derivative_order", 1)
    return dn_along_axis(data, dim=target_dim, order=n_order)
def _n_th_derivative_with_SG(self, data: xr.DataArray, **kwargs: int) -> xr.DataArray:
"""Apply second derivative using Savitzky-Golay filter.
Args:
data (xr.DataArray): Input data to be processed.
**kwargs: Parameters for the Savitzky-Golay filter.
Returns:
xr.DataArray: The second derivative of the input data.
"""
axis = kwargs.get("axis", data.dims[0])
order = kwargs.get("order", 1)
window_length = kwargs.get("window_length", 5)
polyorder = kwargs.get("polyorder", 1)
if window_length % 2 == 0:
self.log_message("❌ Window length must be odd for Savitzky-Golay filter.\n")
return data
if polyorder <= order:
self.log_message("❌ Polyorder must be larger than Order\n")
return data
if window_length < polyorder:
self.log_message("❌ Polyorder must be less than window_length.\n")
return data
self.log_message("✅ sign-revered 2nd derivative is used, as it has a phyiscal meaning.\n")
filterd = savitzky_golay_filter(
data=data,
window_length=window_length,
polyorder=polyorder,
deriv=order,
dim=axis,
)
return -normalize_max(
filterd,
absolute=True,
keep_attrs=True,
max_value=self.max_intensity,
)
def _maximum_curvature_1d(self, data: xr.DataArray, **kwargs: int) -> xr.DataArray:
    """Compute the sign-reversed, normalised output of ``curvature1d``.

    Args:
        data (xr.DataArray): Data to be processed.
        **kwargs: ``axis`` (default: first dimension of ``data``) and
            ``coefficient a`` (default 0.1, passed as ``alpha``).

    Returns:
        xr.DataArray: Negated result of ``normalize_max`` applied to the curvature,
        scaled with ``self.max_intensity``.
    """
    target_dim = kwargs.get("axis", data.dims[0])
    self.log_message("✅ sign-revered curvature is used, as it has a phyiscal meaning.\n")
    alpha = kwargs.get("coefficient a", 0.1)
    curvature = curvature1d(data, dim=target_dim, alpha=alpha)
    normalized = normalize_max(
        curvature,
        absolute=True,
        keep_attrs=True,
        max_value=self.max_intensity,
    )
    return -normalized
def _maximum_curvature_2d(self, data: xr.DataArray, **kwargs: int) -> xr.DataArray:
dims = cast("tuple[Hashable, Hashable]", kwargs.get("dims", data.dims))
if kwargs.get("weight_2D", 1.0) == 0:
self.log_message("❌ weight 2D must not be 0\n")
return data
self.log_message("✅ sign-revered curvature is used, as it has a phyiscal meaning.\n")
return -normalize_max(
curvature2d(
data,
dims=dims,
alpha=kwargs.get("coefficient a", 0.1),
weight2d=kwargs.get("weight_2D", 1.0),
),
absolute=True,
keep_attrs=True,
max_value=self.max_intensity,
)
def _minimum_gradient(self, data: xr.DataArray, **kwargs: int) -> xr.DataArray:
    """Apply ``minimum_gradient`` to ``data``; keyword arguments are accepted but unused.

    Args:
        data (xr.DataArray): Data to be processed.
        **kwargs: Ignored.

    Returns:
        xr.DataArray: Whatever ``minimum_gradient`` returns.
    """
    del kwargs
    result = minimum_gradient(data)
    return result
# --------- Helper Functions ---------#
def _derivative_slider(data: xr.DataArray) -> dict[Hashable, pn.widgets.Widget]:
    """Create the widgets for the plain derivative: an axis selector and an order slider.

    Args:
        data (xr.DataArray): DataArray to be processed; its dimensions become the axis options.

    Returns:
        dict[Hashable, pn.widgets.Widget]: Widgets keyed by ``"axis"`` and ``"derivative_order"``.
    """
    axis_select = pn.widgets.Select(name="axis", options=list(data.dims))
    order_slider = pn.widgets.IntSlider(
        name="Derivative Order",
        start=1,
        end=10,
        step=1,
        value=1,
    )
    return {"axis": axis_select, "derivative_order": order_slider}
def _savgol_deriv_slider(data: xr.DataArray) -> dict[Hashable, pn.widgets.Widget]:
    """Create the widgets for the Savitzky-Golay derivative.

    Args:
        data (xr.DataArray): DataArray to be processed; its dimensions become the axis options.

    Returns:
        dict[Hashable, pn.widgets.Widget]: Widgets keyed by ``"axis"``, ``"order"``,
        ``"window_length"`` and ``"polyorder"``.
    """
    axis_select = pn.widgets.Select(name="axis", options=list(data.dims))
    order_slider = pn.widgets.IntSlider(name="Order", start=1, end=6, step=1, value=1)
    # Start 1 with step 2 keeps the window length on odd values.
    window_slider = pn.widgets.IntSlider(
        name="Window Length",
        start=1,
        end=25,
        step=2,
        value=5,
    )
    polyorder_slider = pn.widgets.IntSlider(
        name="Polyorder",
        start=0,
        end=6,
        step=1,
        value=1,
    )
    return {
        "axis": axis_select,
        "order": order_slider,
        "window_length": window_slider,
        "polyorder": polyorder_slider,
    }
def _max_curvature_1d_slider(data: xr.DataArray) -> dict[Hashable, pn.widgets.Widget]:
    """Create the widgets for 1D maximum curvature.

    Args:
        data (xr.DataArray): DataArray to be processed; its dimensions become the axis options.

    Returns:
        dict[Hashable, pn.widgets.Widget]: Widgets keyed by ``"axis"`` and ``"coefficient a"``.
    """
    axis_select = pn.widgets.Select(name="axis", options=list(data.dims))
    coefficient_slider = pn.widgets.FloatSlider(
        name="Coefficient a",
        start=0.0,
        end=1,
        step=0.0001,
        value=0.1,
        format="0.0000",
    )
    return {"axis": axis_select, "coefficient a": coefficient_slider}
def _max_curvature_2d_slider() -> dict[Hashable, pn.widgets.Widget]:
    """Create the widgets for 2D maximum curvature.

    Returns:
        dict[Hashable, pn.widgets.Widget]: Sliders keyed by ``"coefficient a"`` and ``"weight_2D"``.
    """
    coefficient_slider = pn.widgets.FloatSlider(
        name="Coefficient a",
        start=0.0,
        end=1,
        step=0.0001,
        value=0.1,
        format="0.0000",
    )
    weight_slider = pn.widgets.FloatSlider(
        name="Weight 2D",
        value=1.0,
        start=-10.0,
        end=10.0,
        step=0.001,
        format="0.000",
    )
    return {"coefficient a": coefficient_slider, "weight_2D": weight_slider}
def _iteration_slider() -> dict[Hashable, pn.widgets.Widget]:
    """Create a dictionary holding a single iteration-count slider.

    Returns:
        dict[Hashable, pn.widgets.Widget]: An integer slider (1 to 10, default 1) keyed by
        ``"iteration"``.
    """
    iteration = pn.widgets.IntSlider(name="Iteration", start=1, end=10, step=1, value=1)
    return {"iteration": iteration}
def _gaussian_slider(data: xr.DataArray) -> dict[Hashable, pn.widgets.Widget]:
    """Generate the Gaussian smoothing sliders: one iteration slider plus one sigma per dimension.

    Args:
        data (xr.DataArray): DataArray to be smoothed.

    Returns:
        dict[Hashable, pn.widgets.Widget]: The ``"iteration"`` slider followed by one
        float slider per dimension of ``data``, keyed by the dimension name.
    """
    sliders = _iteration_slider()
    # The stride lookup does not depend on the loop variable, so compute it once
    # (assumed to be side-effect free).
    strides = data.G.stride(generic_dim_names=False)
    for dim in data.dims:
        # NOTE(review): the upper bound is 100x the stride of the dimension; if that
        # stride can be negative or very small, ``end`` drops below ``value`` -- confirm.
        sliders[dim] = pn.widgets.FloatSlider(
            name=f"Sigma {dim}",
            start=0,
            end=round(strides[dim].item() * 100, 2),
            step=0.001,
            value=0.1,
            format="0.000",
        )
    return sliders
def _boxcar_slider(data: xr.DataArray) -> dict[Hashable, pn.widgets.Widget]:
    """Generate the boxcar smoothing sliders: one iteration slider plus one kernel size per dimension.

    Args:
        data (xr.DataArray): DataArray to be smoothed.

    Returns:
        dict[Hashable, pn.widgets.Widget]: The ``"iteration"`` slider followed by one
        float slider per dimension of ``data``, keyed by the dimension name.
    """
    sliders = _iteration_slider()
    # The stride lookup does not depend on the loop variable, so compute it once
    # (assumed to be side-effect free).
    strides = data.G.stride(generic_dim_names=False)
    for dim in data.dims:
        # NOTE(review): the upper bound is 100x the stride of the dimension; if that
        # stride can be negative or very small, ``end`` drops below ``value`` -- confirm.
        sliders[dim] = pn.widgets.FloatSlider(
            name=f"Kernel Size {dim}",
            start=0.0,
            end=round(strides[dim].item() * 100, 2),
            step=0.001,
            value=0.1,
            format="0.000",
        )
    return sliders
def _savgol_slider(data: xr.DataArray) -> dict[Hashable, pn.widgets.Widget]:
    """Create a window-length and a polyorder slider for every dimension of ``data``.

    Args:
        data (xr.DataArray): DataArray to be smoothed.

    Returns:
        dict[Hashable, pn.widgets.Widget]: Sliders keyed by ``window_length_<dim>`` and
        ``polyorder_<dim>``, in dimension order.
    """
    sliders: dict[Hashable, pn.widgets.Widget] = {}
    for dim in data.dims:
        # Start 1 with step 2 keeps the window length on odd values.
        window_slider = pn.widgets.IntSlider(
            name=f"Window Length {dim}",
            value=5,
            start=1,
            end=25,
            step=2,
        )
        polyorder_slider = pn.widgets.IntSlider(
            name=f"Polyorder {dim}",
            value=1,
            start=0,
            end=6,
            step=1,
        )
        sliders.update(
            {
                f"window_length_{dim}": window_slider,
                f"polyorder_{dim}": polyorder_slider,
            },
        )
    return sliders