"""Utilities related to interpretation of model results.
This borrows ideas heavily from fastai, which provides interpreter classes
for different kinds of models.
"""
from dataclasses import dataclass, field
import math
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
import pytorch_lightning as pl
import torch
import tqdm
from typing import List, Any, Optional, Tuple, Union
# Public API of this module.
__all__ = ["Interpretation", "InterpretationItem"]
@dataclass
class InterpretationItem:
    """Provides tools to introspect model performance on a single item.

    An item bundles one sample's ground truth, the model's prediction, the
    per-sample loss, the sample's index into the underlying indexed dataset
    and the DataLoader the sample was drawn from.
    """

    # Ground-truth target for this sample.
    target: Any
    # Model prediction for this sample.
    predicted_target: Any
    # Per-sample loss. Annotated as float, but may hold a 0-d tensor;
    # ``show`` converts it with ``float`` before formatting.
    loss: float
    # Index of the sample within the underlying (unwrapped) dataset.
    index: int
    # The DataLoader this item was drawn from.
    parent_dataloader: DataLoader

    @property
    def dataset(self):
        """Fetch the underlying indexed dataset that contains this item.

        The loader's dataset may be wrapped in one or more ``Subset`` views.
        All of them are unwrapped, because ``index`` is used to index the
        underlying dataset directly. A dataset is only accepted when its
        sentinel attribute ``is_indexed`` is truthy. This is better than
        returning junk data, which is what happens when indexing a shuffled
        view over the dataset.

        Returns:
            The innermost (non-``Subset``) dataset.

        Raises:
            ValueError: if the unwrapped dataset is not marked ``is_indexed``.
        """
        dset = self.parent_dataloader.dataset
        # A Subset of a Subset is possible, so unwrap every level.
        while isinstance(dset, Subset):
            dset = dset.dataset
        # Explicit check rather than ``assert``, so it still runs under ``python -O``.
        if not getattr(dset, "is_indexed", False):
            raise ValueError(
                "Expected an indexed dataset (with a truthy `is_indexed` attribute), "
                f"got {type(dset).__name__}"
            )
        return dset

    def show(self, input_formatter, target_formatter, ax=None, pullback=True):
        """Plot this item's input, ground truth and prediction onto ``ax``.

        See also the ``show`` method of ``Interpretation``.

        Args:
            input_formatter: Object whose ``show(x, ax)`` draws the raw input,
                or None to skip drawing the input.
            target_formatter: Object whose ``show(value, ax)`` draws a target,
                or None to skip drawing targets. If it has a ``context``
                attribute, that attribute is set to
                ``dict(is_ground_truth=True)`` before the ground truth is drawn
                and to ``dict(is_ground_truth=False)`` before the prediction.
            ax: Matplotlib axes to draw on. A new figure is created when None.
            pullback: When True, target and prediction are first passed through
                ``decodes_target``.
        """
        if ax is None:
            _, ax = plt.subplots()
        dset = self.dataset
        # ``no_transforms`` is a dataset-provided context manager, presumably
        # disabling transforms so the raw input is shown (confirm on the dataset).
        with dset.no_transforms():
            x = dset[self.index][0]
        if input_formatter is not None:
            input_formatter.show(x, ax)
        ax.set_title(
            "Item {index}; loss={loss:.3f}\n".format(index=self.index, loss=float(self.loss))
        )
        if target_formatter is None:
            return
        # Ground truth first, then the prediction, on the same axes.
        for is_ground_truth, value in ((True, self.target), (False, self.predicted_target)):
            if hasattr(target_formatter, "context"):
                target_formatter.context = dict(is_ground_truth=is_ground_truth)
            target_formatter.show(self.decodes_target(value) if pullback else value, ax)

    def decodes_target(self, value: Any) -> Any:
        """Pull a target or prediction backwards through the dataset's transforms.

        If the dataset's ``transforms`` object exposes ``decodes_target``,
        ``value`` is passed through it, so that targets and predictions can be
        plotted in a natural space. How far the pullback goes (e.g. stopping at
        an irreversible transform) is up to that transforms object. Otherwise,
        including when the dataset has no ``transforms``, ``value`` is returned
        unchanged.
        """
        tfm = getattr(self.dataset, "transforms", None)
        decode = getattr(tfm, "decodes_target", None)
        if decode is None:
            return value
        return decode(value)
@dataclass
class Interpretation:
    """Provides utilities to interpret predictions of a model.

    Importantly, this is not intended to provide any model introspection
    tools.

    Construction (see ``__post_init__``) runs the model over every batch of
    ``train_dataloader`` and of each validation loader, storing one
    ``InterpretationItem`` per sample.
    """

    # Model to interpret. It is called on batches and must expose a per-sample
    # ``criterion(prediction, target)``.
    model: pl.LightningModule
    # Training loader. Its ``iter_all()`` must yield ``(x, y, indices)`` batches.
    train_dataloader: DataLoader
    # A single validation DataLoader, a list of them, or None.
    val_dataloaders: Union[DataLoader, List[DataLoader], None]
    # If True, ``items`` returns training items; otherwise validation items.
    train: bool = True
    # Which validation loader ``items`` reads when ``train`` is False.
    val_index: int = 0
    # Both populated by ``__post_init__``.
    train_items: List[InterpretationItem] = field(init=False, repr=False)
    val_item_lists: List[List[InterpretationItem]] = field(init=False, repr=False)

    @property
    def items(self) -> List[InterpretationItem]:
        """The items currently being interpreted.

        This is ``train_items`` when ``train`` is True. Otherwise it is the
        item list of the validation loader selected by ``val_index``.
        """
        if self.train:
            return self.train_items
        return self.val_item_lists[self.val_index]

    def top_losses(self, ascending=False) -> List[InterpretationItem]:
        """Return ``items`` ordered by loss, largest first unless ``ascending``.

        The sort is stable, so items with equal loss keep their original order.
        """
        return sorted(self.items, key=lambda item: item.loss, reverse=not ascending)

    def show(
        self,
        n_items: Optional[Union[int, Tuple[int, int]]] = 9,
        items: Optional[List[InterpretationItem]] = None,
        input_formatter=None,
        target_formatter=None,
    ) -> None:
        """Plots a subset of the interpreted items in a grid.

        For each item, we "plot" its data, its label, and model performance
        characteristics on this item. For example, on an image classification
        task this might mean plotting the image, the image's class name as a
        label above it, the predicted class, and the numerical loss.

        Args:
            n_items: Ignored when ``items`` is given. Its meaning depends on type:
                - An int selects that many top-loss items, laid out in a
                  near-square grid.
                - A ``(rows, cols)`` pair fixes the grid and selects
                  ``rows * cols`` top-loss items.
                - None selects all items.
            items: Explicit items to plot, laid out in a near-square grid.
            input_formatter: Forwarded to ``InterpretationItem.show``.
            target_formatter: Forwarded to ``InterpretationItem.show``.
        """
        layout = None
        if items is None:
            if isinstance(n_items, (tuple, list)):
                layout = (int(n_items[0]), int(n_items[1]))
                n_items = layout[0] * layout[1]
            elif n_items is None:
                n_items = len(self.items)
            items = self.top_losses()[:n_items]
        items = list(items)
        if not items:
            # Nothing to draw; plt.subplots needs at least one row and column.
            return
        if layout is None:
            # Size the grid from the items actually available, which may be
            # fewer than requested.
            side = int(math.ceil(len(items) ** 0.5))
            layout = (side, side)
        n_rows, n_cols = layout
        # figsize is (width, height): 3 units per column, 4 per row.
        # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 grid.
        _, axes = plt.subplots(
            n_rows, n_cols, figsize=(n_cols * 3, n_rows * 4), squeeze=False
        )
        padded = items + [None] * (n_rows * n_cols - len(items))
        for item, ax in zip(padded, axes.ravel()):
            if item is None:
                # Hide grid cells left over once every item is placed.
                ax.axis("off")
            else:
                item.show(input_formatter, target_formatter, ax)
        plt.tight_layout()

    @classmethod
    def from_trainer(cls, trainer: pl.Trainer):
        """Builds an interpreter from an instance of a `pytorch_lightning.Trainer`.

        Uses the trainer's model, training loader and validation loaders.
        """
        return cls(trainer.model, trainer.train_dataloader, trainer.val_dataloaders)

    def dataloader_to_item_list(self, dataloader: DataLoader) -> List[InterpretationItem]:
        """Run the model over ``dataloader`` and wrap every sample as an item.

        ``dataloader.iter_all()`` must yield ``(x, y, indices)`` batches. The
        model is evaluated in eval mode without gradient tracking. Its previous
        train/eval mode is restored afterwards, even if an error is raised.

        Returns:
            One ``InterpretationItem`` per sample. Each holds the squeezed
            target, the squeezed prediction, and the squeezed per-sample loss
            computed with ``model.criterion``.
        """
        items = []
        was_training = self.model.training
        # no_grad alone does not disable dropout / batch-norm updates;
        # eval mode makes predictions match inference behaviour.
        self.model.eval()
        try:
            with torch.no_grad():
                for x, y, indices in tqdm.tqdm(dataloader.iter_all()):
                    # NOTE(review): x is not moved to the model's device; this
                    # assumes loader output and model share a device.
                    y_hat = self.model(x).cpu()
                    for yi, yi_hat, index in zip(
                        torch.unbind(y, dim=0),
                        torch.unbind(y_hat, dim=0),
                        torch.unbind(indices, dim=0),
                    ):
                        loss = self.model.criterion(yi_hat, yi)
                        items.append(
                            InterpretationItem(
                                torch.squeeze(yi),
                                torch.squeeze(yi_hat),
                                torch.squeeze(loss),
                                int(index),
                                dataloader,
                            )
                        )
        finally:
            self.model.train(was_training)
        return items

    def __post_init__(self):
        """Populates ``train_items`` and ``val_item_lists``.

        This is done by iterating through the dataloaders and pushing their
        data through the model. A single validation DataLoader is treated as a
        one-element list, and None as no validation loaders.
        """
        self.train_items = self.dataloader_to_item_list(self.train_dataloader)
        val_loaders = self.val_dataloaders
        if val_loaders is None:
            val_loaders = []
        elif isinstance(val_loaders, DataLoader):
            # Iterating a bare DataLoader would yield batches, not loaders.
            val_loaders = [val_loaders]
        self.val_item_lists = [self.dataloader_to_item_list(dl) for dl in val_loaders]