Source code for arpes.deep_learning.transforms

"""Implements transform pipelines for pytorch_lightning with basic inverse transform."""
from dataclasses import dataclass, field
from typing import Callable, List, Any

__all__ = ["ComposeBoth", "ReversibleLambda", "Identity"]

[docs]class Identity: """Represents a reversible identity transform.""" def encodes(self, x): return x def __call__(self, x): return x def decodes(self, x): return x def __repr__(self): return "Identity()"
_identity = Identity()
[docs]@dataclass class ReversibleLambda: """A reversible anonymous function, so long as the caller supplies an inverse.""" encodes: Callable = field(repr=False) decodes: Callable = field(default=lambda x: x, repr=False) def __call__(self, value): """Apply the inner lambda to the data in forward pass.""" return self.encodes(value)
[docs]@dataclass class ComposeBoth: """Like `torchvision.transforms.Compose` but this operates on data & target in each transform.""" transforms: List[Any] def __post_init__(self): """Replace missing transforms with identities.""" safe_transforms = [] for t in self.transforms: if isinstance(t, (tuple, list)): xt, yt = t t = [xt or _identity, yt or _identity] safe_transforms.append(t) self.original_transforms = self.transforms self.transforms = safe_transforms def __call__(self, x, y): """If this transform has separate data and target functions, apply separately. Otherwise, we apply the single transform to both the data and the target. """ for t in self.transforms: if isinstance(t, (list, tuple)): xt, yt = t x, y = xt(x), yt(y) else: x, y = t(x, y) return x, y def decodes_target(self, y): """Pull the target back in the transform stack as far as possible. This is necessary only for the predicted target because otherwise we can *always* push the ground truth target and input forward in the transform stack. This is imperfect because for some transforms we need X and Y in order to process the data. """ for t in self.transforms[::-1]: if isinstance(t, (list, tuple)): _, yt = t y = yt.decodes(y) else: break return y def __repr__(self): """Show both of the constitutent parts of this transform.""" return ( self.__class__.__name__ + "(\n\t" + "\n\t".join([str(t) for t in self.original_transforms]) + "\n)" )