"""Utilities related to interpretation of model results.
This borrows ideas heavily from fastai which provides interpreter classes
for different kinds of models.
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from itertools import starmap
from typing import TYPE_CHECKING, Any
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data.dataset import Dataset, Subset
from arpes.helper.jupyter import get_tqdm
if TYPE_CHECKING:
import pytorch_lightning as pl
from _typeshed import Incomplete
from matplotlib.axes import Axes
from torch.utils.data import DataLoader
__all__ = ["Interpretation", "InterpretationItem"]

# Progress-bar wrapper used when iterating over dataloaders; resolved once, at import time.
# NOTE(review): ``get_tqdm`` presumably picks a notebook- vs. console-friendly tqdm — confirm.
tqdm = get_tqdm()
@dataclass
class InterpretationItem:
    """Provides tools to introspect model performance on a single item.

    The dataclass-generated ``__init__`` takes the fields in the order they are
    declared below, so positional construction must follow that order.
    """

    # Ground-truth target for this item.
    target: Any
    # The model's prediction for this item.
    predicted_target: Any
    # Per-item loss; ``show`` formats it through ``float(...)``, so any
    # scalar-like value (e.g. a 0-d tensor) is accepted.
    loss: float
    # Position of the item inside the unwrapped dataset returned by ``dataset``.
    index: int
    # Dataloader the item was drawn from; used to recover the dataset.
    parent_dataloader: DataLoader

    @property
    def dataset(self) -> Dataset:
        """Fetches the original dataset used to train and containing this item.

        We need to unwrap the dataset in case we are actually dealing
        with a ``Subset`` (possibly a ``Subset`` of a ``Subset``). We should
        obtain an indexed Dataset at the end of the day, and we will know this
        is the case because we use the sentinel attribute ``is_indexed`` to
        mark this.

        This may fail sometimes, but this is better than returning junk
        data which is what happens if we get a shuffled view over the
        dataset.

        Raises:
            TypeError: If the unwrapped dataset is not marked with
                ``is_indexed = True``.
        """
        dset = self.parent_dataloader.dataset
        # Peel off every Subset layer, not just the outermost one.
        while isinstance(dset, Subset):
            dset = dset.dataset
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if getattr(dset, "is_indexed", False) is not True:
            msg = (
                f"Expected an indexed dataset (with `is_indexed = True`), "
                f"got {type(dset).__name__}."
            )
            raise TypeError(msg)
        return dset

    def show(
        self,
        input_formatter: Incomplete,
        target_formatter: Incomplete,
        ax: Axes | None = None,
        *,
        pullback: bool = True,
    ) -> None:
        """Plots item onto the provided axes. See also the `show` method of `Interpretation`.

        Args:
            input_formatter: Object with a ``show(x, ax)`` method used to draw the
                item's input, or ``None`` to skip it.
            target_formatter: Object with a ``show(value, ax)`` method, called first
                with the ground truth and then with the prediction, or ``None`` to
                skip both. If it has a ``context`` attribute, that attribute is set
                to ``{"is_ground_truth": True}`` / ``{"is_ground_truth": False}``
                before the respective call.
            ax: Axes to draw into; a new figure/axes pair is created when omitted.
            pullback: Whether to pass target and prediction through
                ``decodes_target`` before plotting.
        """
        if ax is None:
            _, ax = plt.subplots()

        dset = self.dataset
        # Fetch the input inside ``no_transforms()`` (presumably the raw,
        # untransformed sample); element 0 of the dataset item is the input.
        with dset.no_transforms():
            x = dset[self.index][0]

        if input_formatter is not None:
            input_formatter.show(x, ax)

        ax.set_title(
            f"Item {self.index}; loss={float(self.loss):.3f}\n",
        )

        if target_formatter is not None:
            if hasattr(target_formatter, "context"):
                target_formatter.context = {"is_ground_truth": True}

            target = self.decodes_target(self.target) if pullback else self.target
            target_formatter.show(target, ax)

            if hasattr(target_formatter, "context"):
                target_formatter.context = {"is_ground_truth": False}

            predicted = (
                self.decodes_target(self.predicted_target) if pullback else self.predicted_target
            )
            target_formatter.show(predicted, ax)

    def decodes_target(self, value: Incomplete) -> Incomplete:
        """Pulls the predicted target backwards through the transformation stack.

        Pullback continues until an irreversible transform is met in order
        to be able to plot targets and predictions in a natural space.

        If the dataset's ``transforms`` object has no ``decodes_target`` method,
        ``value`` is returned unchanged.
        """
        tfm = self.dataset.transforms
        if hasattr(tfm, "decodes_target"):
            return tfm.decodes_target(value)
        return value
@dataclass
class Interpretation:
    """Provides utilities to interpret predictions of a model.

    Importantly, this is not intended to provide any model introspection
    tools.

    On construction (see ``__post_init__``) every batch of the training
    dataloader and of each validation dataloader is pushed through ``model``
    and each sample is wrapped in an ``InterpretationItem``.
    """

    model: pl.LightningModule
    train_dataloader: DataLoader
    # A single validation dataloader or a sequence of them.
    val_dataloaders: DataLoader | list[DataLoader]
    # When True, ``items`` refers to the training items; otherwise to the
    # validation items of the loader selected by ``val_index``.
    train: bool = True
    val_index: int = 0

    train_items: list[InterpretationItem] = field(init=False, repr=False)
    val_item_lists: list[list[InterpretationItem]] = field(init=False, repr=False)

    @property
    def items(self) -> list[InterpretationItem]:
        """All of the currently selected ``InterpretationItem`` instances.

        Returns the training items when ``train`` is true, otherwise the items
        of the validation loader at position ``val_index``.
        """
        if self.train:
            return self.train_items
        return self.val_item_lists[self.val_index]

    def top_losses(self, *, ascending: bool = False) -> list[InterpretationItem]:
        """Orders the items by loss.

        Args:
            ascending: If True, the smallest losses come first; by default the
                largest losses come first.
        """

        def key(item: InterpretationItem) -> Incomplete:
            return item.loss if ascending else -item.loss

        return sorted(self.items, key=key)

    def show(
        self,
        n_items: int | tuple[int, int] = 9,
        items: list[InterpretationItem] | None = None,
        input_formatter: Incomplete = None,
        target_formatter: Incomplete = None,
    ) -> None:
        """Plots a subset of the interpreted items.

        For each item, we "plot" its data, its label, and model performance characteristics
        on this item.

        For example, on an image classification task this might mean to plot the image,
        the image's class name as a label above it, the predicted class, and the numerical loss.

        Args:
            n_items: Either the number of highest-loss items to plot, laid out on
                the smallest square grid that holds them, or an explicit
                ``(n_rows, n_cols)`` grid that is filled with the
                ``n_rows * n_cols`` highest-loss items. Ignored when ``items`` is given.
            items: Explicit items to plot, laid out on the smallest square grid
                that holds them.
            input_formatter: Forwarded to ``InterpretationItem.show``.
            target_formatter: Forwarded to ``InterpretationItem.show``.
        """
        if items is None:
            if isinstance(n_items, tuple):
                layout = n_items
                # A tuple describes the grid itself; fill every cell.
                n_items = layout[0] * layout[1]
            else:
                side = math.ceil(n_items**0.5)
                layout = (side, side)
            items = self.top_losses()[:n_items]
        else:
            side = math.ceil(len(items) ** 0.5)
            layout = (side, side)

        items = list(items)
        n_rows, n_cols = layout
        if n_rows <= 0 or n_cols <= 0:
            # Nothing to draw (e.g. no items); plt.subplots needs a positive grid size.
            return

        # squeeze=False keeps ``axes`` 2-D even for a 1x1 grid, so ``ravel`` always works.
        # figsize is (width, height): width scales with columns, height with rows.
        _, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(n_cols * 3, n_rows * 4),
            squeeze=False,
        )
        flat_axes = axes.ravel()
        # There may be fewer items than grid cells (non-square counts, or fewer
        # items available than requested); spare cells are blanked out.
        padded = items + [None] * (len(flat_axes) - len(items))
        for item, ax in zip(padded, flat_axes, strict=True):
            if item is None:
                ax.axis("off")
            else:
                item.show(input_formatter, target_formatter, ax)
        plt.tight_layout()

    @classmethod
    def from_trainer(cls: type[Incomplete], trainer: pl.Trainer) -> Interpretation:
        """Builds an interpreter from an instance of a `pytorch_lightning.Trainer`.

        Uses the trainer's model, training dataloader and validation dataloader(s).
        """
        return cls(trainer.model, trainer.train_dataloader, trainer.val_dataloaders)

    def dataloader_to_item_list(self, dataloader: DataLoader) -> list[InterpretationItem]:
        """Converts data loader into a list of interpretation items corresponding to the data.

        Each batch yielded by ``dataloader.iter_all()`` is unpacked as
        ``(x, y, indices)``. The model runs on ``x`` without gradient tracking,
        and each sample's loss is computed separately with
        ``self.model.criterion(prediction, target)``.
        """
        items: list[InterpretationItem] = []
        for x, y, indices in tqdm(dataloader.iter_all()):
            with torch.no_grad():
                y_hat = self.model(x).cpu()
                y_hats = torch.unbind(y_hat, dim=0)
                ys = torch.unbind(y, dim=0)
                losses = list(starmap(self.model.criterion, zip(y_hats, ys, strict=True)))

            for yi, yi_hat, loss, index in zip(
                ys,
                y_hats,
                losses,
                torch.unbind(indices, dim=0),
                strict=True,
            ):
                # Keyword arguments: the dataclass field order is
                # (target, predicted_target, loss, index, parent_dataloader),
                # so positional construction is easy to get wrong.
                # NOTE(review): the model is not switched to eval mode here, so
                # dropout/batch-norm layers behave as in training — confirm intended.
                items.append(
                    InterpretationItem(
                        target=torch.squeeze(yi),
                        predicted_target=torch.squeeze(yi_hat),
                        loss=torch.squeeze(loss),
                        index=int(index),
                        parent_dataloader=dataloader,
                    ),
                )
        return items

    def __post_init__(self) -> None:
        """Populates train_items and val_item_lists.

        This is done by iterating through the dataloaders and pushing data through the models.
        A single validation loader (anything exposing ``iter_all``) is treated as a
        one-element list, so ``val_item_lists`` always holds one list per loader.
        """
        self.train_items = self.dataloader_to_item_list(self.train_dataloader)

        val_loaders = self.val_dataloaders
        if hasattr(val_loaders, "iter_all"):
            # Iterating a single loader would yield batches, not loaders.
            val_loaders = [val_loaders]
        self.val_item_lists = [self.dataloader_to_item_list(dl) for dl in val_loaders]