Potpourri

Miscellaneous utilities.

Math

fannypack.utils.tril_count_from_matrix_dim(matrix_dim: int)

Computes the number of lower-triangular terms in a square matrix of shape (matrix_dim, matrix_dim).

Parameters

matrix_dim (int) – Dimension of square matrix.

Returns

int – Count of lower-triangular terms.
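A minimal sketch of the arithmetic, assuming the count includes the diagonal. Row i (0-indexed) of an n-by-n lower-triangular matrix holds i + 1 terms, so the total is n(n + 1)/2. The function name below is hypothetical and is not fannypack's own code.

```python
def tril_count_sketch(matrix_dim: int) -> int:
    # Assumes diagonal terms are counted: 1 + 2 + ... + n = n * (n + 1) / 2.
    return matrix_dim * (matrix_dim + 1) // 2


print(tril_count_sketch(3))  # 6
```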

fannypack.utils.matrix_dim_from_tril_count(tril_count: int)

Computes the dimension of a square lower-triangular matrix from the count of its lower-triangular terms.

Parameters

tril_count (int) – Count of lower-triangular terms.

Returns

int – Dimension of square matrix.
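This inverts the count formula. Solving tril_count = n(n + 1)/2 for n with the quadratic formula gives n = (sqrt(1 + 8 * tril_count) - 1)/2. A minimal sketch, assuming the diagonal is included in the count; the function name and the ValueError on non-triangular counts are illustrative choices, not documented fannypack behavior.

```python
import math


def matrix_dim_from_tril_count_sketch(tril_count: int) -> int:
    # Integer square root avoids floating-point rounding for large counts.
    matrix_dim = (math.isqrt(1 + 8 * tril_count) - 1) // 2
    # Hypothetical validation: only triangular numbers are valid counts.
    if matrix_dim * (matrix_dim + 1) // 2 != tril_count:
        raise ValueError(f"{tril_count} is not a valid lower-triangular count")
    return matrix_dim


print(matrix_dim_from_tril_count_sketch(6))  # 3
```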

fannypack.utils.tril_from_vector(lower_vector: torch.Tensor) → torch.Tensor

Computes lower-triangular square matrices from flattened vectors of their nonzero terms. Supports arbitrary batch dimensions.

Parameters

lower_vector (torch.Tensor) – Vectors containing the nonzero terms of a square lower-triangular matrix. Shape should be (*, tril_count).

Returns

torch.Tensor – Square matrices. Shape should be (*, matrix_dim, matrix_dim).
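One way to implement this is to scatter the vector into a zero matrix at the indices returned by `torch.tril_indices`. A minimal sketch, assuming the vector stores terms in the row-major order that `torch.tril_indices` produces; fannypack's actual packing order is not stated here, and the function name is hypothetical.

```python
import math

import torch


def tril_from_vector_sketch(lower_vector: torch.Tensor) -> torch.Tensor:
    tril_count = lower_vector.shape[-1]
    matrix_dim = (math.isqrt(1 + 8 * tril_count) - 1) // 2
    assert matrix_dim * (matrix_dim + 1) // 2 == tril_count, "invalid tril_count"

    # Row/column indices of the lower triangle (diagonal included), row-major.
    rows, cols = torch.tril_indices(matrix_dim, matrix_dim, device=lower_vector.device)

    # Zero-filled output of shape (*, matrix_dim, matrix_dim).
    output = lower_vector.new_zeros(lower_vector.shape[:-1] + (matrix_dim, matrix_dim))
    # Advanced indexing writes every batch element at once.
    output[..., rows, cols] = lower_vector
    return output


print(tril_from_vector_sketch(torch.arange(1.0, 7.0)))
# tensor([[1., 0., 0.],
#         [2., 3., 0.],
#         [4., 5., 6.]])
```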

fannypack.utils.vector_from_tril(tril_matrix: torch.Tensor) → torch.Tensor

Retrieves the lower triangular terms of square matrices as vectors. Supports arbitrary batch dimensions.

Parameters

tril_matrix (torch.Tensor) – Square matrices. Shape should be (*, matrix_dim, matrix_dim).

Returns

torch.Tensor – Flattened vectors. Shape should be (*, tril_count).
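The reverse direction gathers the lower-triangular entries with the same indices. A minimal sketch, again assuming row-major ordering from `torch.tril_indices` (not confirmed as fannypack's layout); the function name is hypothetical. Under that assumption it round-trips with a matching `tril_from_vector` implementation.

```python
import torch


def vector_from_tril_sketch(tril_matrix: torch.Tensor) -> torch.Tensor:
    matrix_dim = tril_matrix.shape[-1]
    assert tril_matrix.shape[-2] == matrix_dim, "expected square matrices"

    # Same row-major lower-triangle indices used when packing.
    rows, cols = torch.tril_indices(matrix_dim, matrix_dim, device=tril_matrix.device)
    # Gather along the last two axes: result has shape (*, tril_count).
    return tril_matrix[..., rows, cols]


m = torch.tensor([[1.0, 0.0, 0.0], [2.0, 3.0, 0.0], [4.0, 5.0, 6.0]])
print(vector_from_tril_sketch(m))  # tensor([1., 2., 3., 4., 5., 6.])
```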

fannypack.utils.gaussian_log_prob(mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor) → torch.Tensor

Computes log probabilities under multivariate Gaussian distributions, with support for arbitrary batch axes.

Naive version of…

torch.distributions.MultivariateNormal(
    mean, covariance
).log_prob(value)

that avoids some Cholesky-related CUDA errors; for details, see https://discuss.pytorch.org/t/cuda-illegal-memory-access-when-using-batched-torch-cholesky/51624

Stolen from @alberthli/@wuphilipp.

Parameters
  • mean (torch.Tensor) – Mean vectors. Shape should be (*, D).

  • covariance (torch.Tensor) – Covariance matrices. Shape should be (*, D, D).

  • value (torch.Tensor) – State vectors. Shape should be (*, D).

Returns

torch.Tensor – Batched log probabilities. Shape should be (*).
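The density being evaluated is log N(x; μ, Σ) = -½ (D log 2π + log det Σ + (x - μ)ᵀ Σ⁻¹ (x - μ)). A minimal sketch of a Cholesky-free evaluation, swapping in the LU-based `torch.linalg.solve` and `torch.linalg.slogdet`; this is an illustrative technique, not necessarily how fannypack implements it, and the function name is hypothetical.

```python
import math

import torch


def gaussian_log_prob_sketch(
    mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    D = mean.shape[-1]
    diff = (value - mean).unsqueeze(-1)  # (*, D, 1)

    # Solve covariance @ y = diff (LU-based) instead of factorizing with Cholesky.
    solved = torch.linalg.solve(covariance, diff)  # (*, D, 1)
    mahalanobis = (diff * solved).sum(dim=(-2, -1))  # (*,)

    # log|det(covariance)|; covariance is assumed symmetric positive-definite.
    _, logdet = torch.linalg.slogdet(covariance)
    return -0.5 * (D * math.log(2.0 * math.pi) + logdet + mahalanobis)
```

For well-conditioned covariances, the result should agree with `torch.distributions.MultivariateNormal(mean, covariance).log_prob(value)` up to floating-point error.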

fannypack.utils.quadratic_matmul(x: torch.Tensor, A: torch.Tensor) → torch.Tensor

Computes \(x^\top A x\), with support for arbitrary batch axes.

Stolen from @alberthli/@wuphilipp.

Parameters
  • x (torch.Tensor) – Vectors. Shape should be (*, D).

  • A (torch.Tensor) – Matrices. Shape should be (*, D, D).

Returns

torch.Tensor – Batched output of multiplication. Shape should be (*).
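The quadratic form \(x^\top A x = \sum_{i,j} x_i A_{ij} x_j\) maps directly onto a single `torch.einsum` call, whose `...` broadcasts over the batch axes. A minimal sketch; the function name is hypothetical and this is not presented as fannypack's implementation.

```python
import torch


def quadratic_matmul_sketch(x: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    # Contract both vector indices against the matrix: sum_ij x_i * A_ij * x_j.
    return torch.einsum("...i,...ij,...j->...", x, A, x)


x = torch.tensor([1.0, 2.0])
A = torch.tensor([[1.0, 0.0], [0.0, 3.0]])
print(quadratic_matmul_sketch(x, A))  # tensor(13.)
```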

Module Freezing

fannypack.utils.freeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None

Freeze the weights of a PyTorch module by setting the requires_grad attributes of enclosed parameters to False.

Parameters
  • module (torch.nn.Module) – Module to freeze.

  • recurse (bool, optional) – If True, then recursively freezes children. Otherwise, only freezes immediate parameters.

fannypack.utils.unfreeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None

Unfreeze the weights of a PyTorch module, which needs to have been frozen with fannypack.utils.freeze_module(). Restores the original requires_grad values of all affected parameters.

Parameters
  • module (torch.nn.Module) – Module to unfreeze.

  • recurse (bool, optional) – If True, then recursively unfreezes children. Otherwise, only unfreezes immediate parameters.
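Restoring the original flags requires remembering them at freeze time. A minimal sketch that stashes each parameter's flag in a hypothetical attribute (`_sketch_original_requires_grad`); fannypack's actual bookkeeping is not described here, and both function names are hypothetical.

```python
import torch

# Hypothetical attribute name used to remember each parameter's original flag.
_ORIGINAL_FLAG = "_sketch_original_requires_grad"


def freeze_module_sketch(module: torch.nn.Module, recurse: bool = True) -> None:
    for param in module.parameters(recurse=recurse):
        # Only record the flag once, so freezing twice doesn't lose it.
        if not hasattr(param, _ORIGINAL_FLAG):
            setattr(param, _ORIGINAL_FLAG, param.requires_grad)
        param.requires_grad = False


def unfreeze_module_sketch(module: torch.nn.Module, recurse: bool = True) -> None:
    for param in module.parameters(recurse=recurse):
        if not hasattr(param, _ORIGINAL_FLAG):
            raise RuntimeError("parameter was not frozen with freeze_module_sketch()")
        # Restore the recorded flag, then drop the bookkeeping attribute.
        param.requires_grad = getattr(param, _ORIGINAL_FLAG)
        delattr(param, _ORIGINAL_FLAG)
```

Storing the flag rather than blindly setting `requires_grad = True` means parameters that were already frozen before the call stay frozen after unfreezing.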

Deprecation Helpers

fannypack.utils.new_name_wrapper(old_name: str, new_name: str, function_or_class: Callable) → Callable

Creates a wrapper for a renamed function or class. Prints a warning the first time a function or class is called with the old name.

Parameters
  • old_name (str) – Old name of function or class. Printed in warning.

  • new_name (str) – New name of function or class. Printed in warning.

  • function_or_class (Callable) – Function or class to wrap.

Returns

Callable – Wrapped function/class.

fannypack.utils.deprecation_wrapper(message: str, function_or_class: Callable) → Callable

Creates a wrapper for a deprecated function or class. Prints a warning the first time a function or class is called.

Parameters
  • message (str) – Warning message.

  • function_or_class (Callable) – Function or class to wrap.

Returns

Callable – Wrapped function/class.
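Both deprecation helpers can be sketched as a closure that warns on the first call and then forwards to the wrapped callable. A minimal sketch covering plain functions only: `functools.wraps` turns a wrapped class into a function, so `isinstance` checks against the wrapper would break. The function names are hypothetical, and using `warnings.warn` (rather than `print`) is an assumption.

```python
import functools
import warnings
from typing import Callable


def deprecation_wrapper_sketch(message: str, function: Callable) -> Callable:
    warned = False

    @functools.wraps(function)
    def wrapped(*args, **kwargs):
        nonlocal warned
        # Warn only on the first call; later calls pass straight through.
        if not warned:
            warnings.warn(message, stacklevel=2)
            warned = True
        return function(*args, **kwargs)

    return wrapped


def new_name_wrapper_sketch(old_name: str, new_name: str, function: Callable) -> Callable:
    # A rename is a deprecation whose message names both identifiers.
    return deprecation_wrapper_sketch(f"{old_name} has been renamed to {new_name}", function)


def add(a, b):
    return a + b


old_add = new_name_wrapper_sketch("old_add", "add", add)
old_add(1, 2)  # warns once: "old_add has been renamed to add"
```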

Debugging

fannypack.utils.pdb_safety_net()

Attaches a “safety net” for unexpected errors in a Python script.

Once called, PDB opens automatically when either (a) the user hits Ctrl+C or (b) an uncaught exception is raised. Helpful for bypassing minor errors, diagnosing problems, and rescuing unsaved models.
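A minimal sketch of one way to build such a safety net with the standard library: replace `sys.excepthook` with a hook that runs `pdb.post_mortem`, and install a SIGINT handler that pauses at the interrupted frame. This is an illustrative approach rather than fannypack's implementation, and the function name is hypothetical. Note that `signal.signal` must be called from the main thread, and `sys.excepthook` only sees uncaught exceptions from the main thread.

```python
import pdb
import signal
import sys
import traceback


def pdb_safety_net_sketch() -> None:
    def excepthook(exc_type, exc_value, exc_traceback):
        # Print the error as usual, then open a post-mortem debugger on it.
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        pdb.post_mortem(exc_traceback)

    def sigint_handler(signum, frame):
        # Ctrl+C pauses in the interrupted frame instead of raising KeyboardInterrupt.
        pdb.Pdb().set_trace(frame)

    sys.excepthook = excepthook
    signal.signal(signal.SIGINT, sigint_handler)
```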