Source code for arpes.plotting.fits
"""Utilities for inspecting fit results by hand by plotting them individually."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from .utils import simple_ax_grid
__all__ = (
"plot_fit",
"plot_fits",
)
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import lmfit as lf
from numpy.typing import NDArray
[docs]
def plot_fit(model_result: lf.model.ModelResult, ax: Axes | None = None) -> Axes:
    """Performs a straightforward plot of the data, residual, and fit to an axis.

    When ``fit_results`` is the return value of ``S.modelfit``, the argument of this
    function is ``fit_results.modelfit_results[n].item()``, where ``n`` is the index.

    This plays the same role as ``ModelResult.plot()``, but uses less space: the
    residual is drawn on a twin y-axis that shares the x-axis with the data and fit,
    rather than in a separate panel.

    Args:
        model_result: The lmfit ``ModelResult`` to display. Its first independent
            variable (looked up in ``model_result.userkws``) is used as the x-axis.
        ax: Axes on which to plot. If ``None``, a new figure and axes are created.

    Returns:
        The axes holding the data and the best fit. The residual is drawn on a twin
        axes created from it with ``ax.twinx()``.
    """
    if ax is None:
        _, ax = plt.subplots()
    assert isinstance(ax, Axes)

    x = model_result.userkws[model_result.model.independent_vars[0]]

    # Residuals get their own y-scale on a twin axis so they can share x with the data.
    ax2 = ax.twinx()
    assert isinstance(ax2, Axes)
    ax2.grid(visible=False)
    ax2.axhline(0, color="green", linestyle="--", alpha=0.5)

    # Data as hollow blue squares, best fit as a solid red line.
    ax.scatter(
        x,
        model_result.data,
        s=10,
        edgecolors="blue",
        marker="s",
        c="white",
        linewidth=1.5,
    )
    ax.plot(x, model_result.best_fit, color="red", linewidth=1.5)

    ax2.scatter(
        x,
        model_result.residual,
        edgecolors="green",
        alpha=0.5,
        s=12,
        marker="s",
        c="white",
        linewidth=1.5,
    )

    # Make the residual axis symmetric about zero and 2.5x its current extent, so the
    # residuals sit in a narrow band around the zero line instead of covering the data.
    ylim = np.max(np.abs(np.asarray(ax2.get_ylim()))) * 2.5
    ax2.set_ylim(bottom=-ylim, top=ylim)

    # nanmin/nanmax: np.min/np.max would propagate a NaN from x, and matplotlib
    # rejects NaN axis limits. Identical to min/max when x has no NaN.
    ax.set_xlim(left=np.nanmin(x), right=np.nanmax(x))
    return ax
[docs]
def plot_fits(
    model_results: list[lf.model.ModelResult] | NDArray[np.object_],
    axs: NDArray[np.object_] | None = None,
) -> None:
    """Plots several fits onto a grid of axes.

    Each fit is drawn with ``plot_fit`` on its own axes.

    Args:
        model_results: The lmfit ``ModelResult`` objects to plot, either as a list or
            as an object array (arrays of any dimensionality are flattened in C order).
        axs: Axes on which to plot, one per fit. If ``None``, a grid with one axes per
            fit is created with ``simple_ax_grid``. An array of axes (such as the 2D
            array returned by ``plt.subplots``) is flattened before use. If fewer axes
            than fits are supplied, the surplus fits are not drawn.
    """
    # Flatten N-D object arrays so iteration yields ModelResult objects, not rows.
    results = (
        model_results.ravel() if isinstance(model_results, np.ndarray) else model_results
    )
    if len(results) == 0:
        # Nothing to plot; avoid asking for an axes grid with zero cells.
        return

    # Explicit None check: `axs or ...` evaluates the truth value of an ndarray,
    # which raises ValueError for arrays holding more than one Axes.
    if axs is None:
        axs = simple_ax_grid(len(results), sharex="col", sharey="row")[1]
    elif isinstance(axs, np.ndarray):
        axs = axs.ravel()

    for axi, model_result in zip(axs, results, strict=False):
        plot_fit(model_result, ax=axi)