nn

class snippets.nn.Affine(loc: Tensor, scale: Tensor, dtype: dtype | None = None)

Apply a fixed affine transform \(y = x A^\intercal + b\), akin to torch.nn.Linear, where \(A\) is scale and \(b\) is loc.

Parameters:
  • loc – Offset of the transform.

  • scale – Scale (matrix) of the transform.

Example

>>> from snippets.nn import Affine
>>> import torch

>>> affine = Affine(loc=1, scale=2)
>>> x = torch.arange(3)
>>> affine(x)
tensor([1., 3., 5.])
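In the scalar example above, scale 2 and loc 1 map each element to \(2x + 1\). For a matrix-valued scale, each row vector \(x\) is mapped to \(x A^\intercal + b\), so output component \(i\) is \(\sum_j A_{ij} x_j + b_i\). The pure-Python sketch below spells out that formula. The helper `affine_rows` is hypothetical and only illustrates the math; it is not how snippets.nn.Affine is implemented. In PyTorch the same expression is computed by `torch.nn.functional.linear(x, A, b)`.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def affine_rows(x: Matrix, A: Matrix, b: Vector) -> Matrix:
    """Compute y = x A^T + b for a batch of row vectors (hypothetical helper).

    x has shape (n, d_in), A has shape (d_out, d_in), b has shape (d_out,).
    """
    return [
        # Output component i of one row: sum_j A[i][j] * x[j] + b[i].
        [sum(a_ij * x_j for a_ij, x_j in zip(row_A, row_x)) + b_i
         for row_A, b_i in zip(A, b)]
        for row_x in x
    ]


# A 2x2 scale matrix and an offset applied to two row vectors.
A = [[2.0, 0.0],
     [1.0, 1.0]]
b = [1.0, -1.0]
x = [[0.0, 1.0],
     [2.0, 3.0]]
print(affine_rows(x, A, b))  # [[1.0, 0.0], [5.0, 4.0]]
```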

class snippets.nn.StopOnPlateau(mode: Literal['min', 'max'] = 'min', patience: int = 20, threshold: float = 0.0001, threshold_mode: Literal['rel', 'abs'] = 'rel')

Stop training when a monitored quantity has stopped improving, akin to torch.optim.lr_scheduler.ReduceLROnPlateau.

Parameters:
  • mode – One of min or max. In min mode, training will stop when the monitored quantity has stopped decreasing; in max mode, training will stop when the monitored quantity has stopped increasing.

  • patience – Number of epochs without improvement after which training will stop.

  • threshold – Threshold for measuring the new optimum, to only focus on significant changes.

  • threshold_mode – One of rel or abs. For rel, dynamic_threshold = best * (1 + threshold) in max mode or dynamic_threshold = best * (1 - threshold) in min mode. For abs, dynamic_threshold = best + threshold in max mode or dynamic_threshold = best - threshold in min mode.
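The four threshold rules above can be condensed into a single check. This sketch assumes strict comparisons against the dynamic threshold, as ReduceLROnPlateau uses. The function is a hypothetical illustration, not the library's is_better method.

```python
def is_better(best: float, candidate: float, mode: str = "min",
              threshold: float = 1e-4, threshold_mode: str = "rel") -> bool:
    """Return True if candidate strictly beats the dynamic threshold derived from best."""
    if mode == "min" and threshold_mode == "rel":
        return candidate < best * (1 - threshold)
    if mode == "min" and threshold_mode == "abs":
        return candidate < best - threshold
    if mode == "max" and threshold_mode == "rel":
        return candidate > best * (1 + threshold)
    if mode == "max" and threshold_mode == "abs":
        return candidate > best + threshold
    raise ValueError(f"unsupported mode/threshold_mode: {mode!r}/{threshold_mode!r}")


print(is_better(1.0, 0.99995))  # False: 0.99995 is not below 1.0 * (1 - 1e-4)
print(is_better(10.0, 11.5, mode="max", threshold=1.0, threshold_mode="abs"))  # True
```

With the default relative threshold, a starting best of infinity (min mode) accepts any finite first value, because `inf * (1 - threshold)` is still infinity.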

Example

>>> from snippets.nn import StopOnPlateau

>>> stop = StopOnPlateau(patience=3)
>>> for _ in range(5):
...     stop.step(3)
False
False
False
False
True
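The output above is consistent with a counter of consecutive non-improving steps that triggers once it exceeds patience. The first call records 3 as the best value. The next three calls raise the counter to 3, and the fifth call raises it to 4, which exceeds the patience of 3. The sketch below reproduces this behavior for min mode with a relative threshold. The class PlateauCounter and its attribute names are hypothetical, and the strict "exceeds patience" rule is an assumption inferred from the example and from ReduceLROnPlateau.

```python
import math


class PlateauCounter:
    """Hypothetical stand-in for StopOnPlateau (mode="min", threshold_mode="rel")."""

    def __init__(self, patience: int = 20, threshold: float = 1e-4) -> None:
        self.patience = patience
        self.threshold = threshold
        self.best = math.inf      # no value observed yet
        self.num_bad_steps = 0    # consecutive steps without significant improvement

    def step(self, value: float) -> bool:
        if value < self.best * (1 - self.threshold):
            # Significant improvement: record the new best and reset the counter.
            self.best = value
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        # Assumed rule: stop once the counter strictly exceeds patience.
        return self.num_bad_steps > self.patience


counter = PlateauCounter(patience=3)
print([counter.step(3) for _ in range(5)])  # [False, False, False, False, True]
```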
classmethod from_scheduler(scheduler: ReduceLROnPlateau, patience_factor: float = 1, patience: int | None = None) → S

Create a StopOnPlateau instance configured from an existing ReduceLROnPlateau scheduler.

Parameters:
  • scheduler – Learning rate scheduler whose configuration is copied.

  • patience_factor – Factor by which to scale the patience of the learning rate scheduler.

  • patience – Patience of the instance (takes precedence over patience_factor).

Returns:

Instance configured from the supplied ReduceLROnPlateau.
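Read literally, the parameters above suggest this rule: copy the comparison settings from the scheduler, and take patience from the explicit argument if given, otherwise from the scheduler's patience scaled by patience_factor. The sketch below is a hypothetical illustration. StopConfig and config_from_scheduler are not part of snippets, and both the set of copied attributes and the rounding of a fractional patience are assumptions. A SimpleNamespace stands in for a real ReduceLROnPlateau, which stores mode, patience, threshold and threshold_mode as attributes.

```python
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional


@dataclass
class StopConfig:
    """Hypothetical container for the settings StopOnPlateau accepts."""
    mode: str
    patience: int
    threshold: float
    threshold_mode: str


def config_from_scheduler(scheduler, patience_factor: float = 1,
                          patience: Optional[int] = None) -> StopConfig:
    """Copy comparison settings; derive patience unless it is given explicitly."""
    if patience is None:
        # Assumption: the scaled patience is rounded to the nearest integer.
        patience = round(scheduler.patience * patience_factor)
    return StopConfig(
        mode=scheduler.mode,
        patience=patience,
        threshold=scheduler.threshold,
        threshold_mode=scheduler.threshold_mode,
    )


# Stand-in exposing the same attribute names as ReduceLROnPlateau.
scheduler = SimpleNamespace(mode="min", patience=10, threshold=1e-4, threshold_mode="rel")
print(config_from_scheduler(scheduler, patience_factor=2).patience)  # 20
print(config_from_scheduler(scheduler, patience_factor=2, patience=5).patience)  # 5
```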

is_better(best: float, candidate: float) → bool

Check if the candidate is better than the current best value.

Parameters:
  • best – Reference value to compare with.

  • candidate – Candidate value to check.

Returns:

Whether the candidate is better than the current best value.

step(value: float) → bool

Update the state with a new value.

Parameters:
  • value – Value of the monitored quantity.

Returns:

Whether the process should stop.

property stop: bool

Whether the process should stop.