Source code for arpes.deep_learning.transforms

"""Implements transform pipelines for pytorch_lightning with basic inverse transform."""

from __future__ import annotations

from dataclasses import Field, dataclass, field
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from _typeshed import Incomplete

__all__ = ["ComposeBoth", "Identity", "ReversibleLambda"]


[docs] class Identity: """Represents a reversible identity transform.""" def encodes(self, x: Incomplete) -> Incomplete: """[TODO:summary]. Args: x: [TODO:description] Returns: [TODO:description] """ return x def __call__(self, x: Incomplete) -> Incomplete: """[TODO:summary]. Args: x: [TODO:description] Returns: [TODO:description] """ return x def decodes(self, x: Incomplete) -> Incomplete: """[TODO:summary]. Args: x: [TODO:description] Returns: [TODO:description] """ return x def __repr__(self) -> str: """[TODO:summary]. Returns: [TODO:description] """ return "Identity()"
_identity = Identity()
[docs] @dataclass class ReversibleLambda: """A reversible anonymous function, so long as the caller supplies an inverse.""" encodes: Field = field(repr=False) decodes: Field = field(default=lambda x: x, repr=False) def __call__(self, value: Incomplete) -> Field[Incomplete]: """Apply the inner lambda to the data in forward pass.""" return self.encodes(value)
[docs] @dataclass class ComposeBoth: """Like `torchvision.transforms.Compose` but it operates on data & target in each transform.""" transforms: list[Any] def __post_init__(self) -> None: """Replace missing transforms with identities.""" safe_transforms = [] for t in self.transforms: if isinstance(t, tuple | list): xt, yt = t safe_transforms.append([xt or _identity, yt or _identity]) else: safe_transforms.append(t) self.original_transforms = self.transforms self.transforms = safe_transforms def __call__(self, x: Incomplete, y: Incomplete) -> Incomplete: """If this transform has separate data and target functions, apply separately. Otherwise, we apply the single transform to both the data and the target. """ for t in self.transforms: if isinstance(t, list | tuple): xt, yt = t x, y = xt(x), yt(y) else: x, y = t(x, y) return x, y def decodes_target(self, y: Incomplete) -> Incomplete: """Pull the target back in the transform stack as far as possible. This is necessary only for the predicted target because otherwise we can *always* push the ground truth target and input forward in the transform stack. This is imperfect because for some transforms we need X and Y in order to process the data. """ for t in self.transforms[::-1]: if isinstance(t, list | tuple): _, yt = t y = yt.decodes(y) else: break return y def __repr__(self) -> str: """Show both of the constituent parts of this transform.""" return ( self.__class__.__name__ + "(\n\t" + "\n\t".join([str(t) for t in self.original_transforms]) + "\n)" )