Potpourri

Miscellaneous utilities.

Math

fannypack.utils.cholesky_inverse(u: torch.Tensor, upper: bool = False) → torch.Tensor

Alternative to torch.cholesky_inverse(), with support for batch dimensions.

Relevant issue tracker: https://github.com/pytorch/pytorch/issues/7500

Parameters
  • u (torch.Tensor) – Triangular Cholesky factor. Shape should be (*, N, N).

  • upper (bool, optional) – Whether u is an upper-triangular (True) or lower-triangular (False, the default) Cholesky factor.

Returns

torch.Tensor – Inverse of the matrix whose Cholesky factor is u. Shape should be (*, N, N).
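As a rough illustration of the math, the sketch below inverts A from its Cholesky factor by solving A X = I with torch.cholesky_solve, which accepts batched inputs. The name cholesky_inverse_sketch is hypothetical, and this is not necessarily how fannypack implements cholesky_inverse().

```python
import torch


def cholesky_inverse_sketch(u: torch.Tensor, upper: bool = False) -> torch.Tensor:
    """Invert A from its Cholesky factor u, for any number of leading batch dims.

    A = u @ u^T when upper=False, and A = u^T @ u when upper=True.
    """
    matrix_dim = u.shape[-1]
    # Broadcast one identity matrix across all batch dimensions of u
    identity = torch.eye(matrix_dim, dtype=u.dtype, device=u.device).expand(u.shape)
    # Solving A X = I with the factored A gives X = A^{-1}
    return torch.cholesky_solve(identity, u, upper=upper)
```

For example, with L = torch.linalg.cholesky(A) for a batch of positive-definite matrices A, cholesky_inverse_sketch(L) should match torch.linalg.inv(A) up to floating-point error.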

fannypack.utils.tril_inverse(tril_matrix: torch.Tensor) → torch.Tensor

Invert a lower-triangular matrix.

Parameters

tril_matrix (torch.Tensor) – Lower-triangular matrix to invert. Shape should be (*, matrix_dim, matrix_dim).

Returns

torch.Tensor – Inverted matrix. Shape should be (*, matrix_dim, matrix_dim).
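A lower-triangular matrix can be inverted by forward substitution against the identity. The hypothetical tril_inverse_sketch below does this with torch.linalg.solve_triangular (available in PyTorch 1.11 and later); it is a sketch, not fannypack's implementation.

```python
import torch


def tril_inverse_sketch(tril_matrix: torch.Tensor) -> torch.Tensor:
    """Invert lower-triangular matrices with arbitrary leading batch dims."""
    matrix_dim = tril_matrix.shape[-1]
    identity = torch.eye(
        matrix_dim, dtype=tril_matrix.dtype, device=tril_matrix.device
    ).expand(tril_matrix.shape)
    # Forward substitution solves T X = I without a general-purpose LU factorization
    return torch.linalg.solve_triangular(tril_matrix, identity, upper=False)
```

The inverse of a lower-triangular matrix is itself lower triangular, so the result keeps the same structure as the input.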

fannypack.utils.cholupdate(L: torch.Tensor, x: torch.Tensor, weight: Optional[Union[torch.Tensor, float]] = None) → torch.Tensor

Batched rank-1 Cholesky update.

Computes the Cholesky decomposition of \(LL^\top + \text{weight} \cdot xx^\top\).

Parameters
  • L (torch.Tensor) – Lower triangular Cholesky decomposition of a PSD matrix. Shape should be (*, matrix_dim, matrix_dim).

  • x (torch.Tensor) – Rank-1 update vector. Shape should be (*, matrix_dim).

  • weight (torch.Tensor or float, optional) – Set to -1 for “downdate”. Shape must be broadcastable with (*, matrix_dim).

Returns

torch.Tensor – New L matrix. Same shape as L.
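The sketch below follows the textbook column-by-column rank-1 update, vectorized over batch dimensions. Assumptions: the name cholupdate_sketch is hypothetical, weight is restricted to a Python float with a default of 1.0 (fannypack also accepts tensors and defaults to None), and a negative weight is handled by folding sqrt(|weight|) into x and switching to the downdate form. A downdate only succeeds if the result stays positive definite; otherwise the square root produces NaN.

```python
import math

import torch


def cholupdate_sketch(L: torch.Tensor, x: torch.Tensor, weight: float = 1.0) -> torch.Tensor:
    """Return L' with L' L'^T = L L^T + weight * x x^T (batched over leading dims)."""
    L = L.clone()  # work on a copy; the caller's factor is left untouched
    sign = 1.0 if weight >= 0 else -1.0  # -1.0 turns the update into a downdate
    x = x * math.sqrt(abs(weight))  # fold |weight| into the update vector (new tensor)
    matrix_dim = x.shape[-1]
    for k in range(matrix_dim):
        L_kk = L[..., k, k]
        x_k = x[..., k]
        r = torch.sqrt(L_kk**2 + sign * x_k**2)
        c = (r / L_kk).unsqueeze(-1)
        s = (x_k / L_kk).unsqueeze(-1)
        L[..., k, k] = r
        # Update column k below the diagonal, then the remaining entries of x
        L[..., k + 1 :, k] = (L[..., k + 1 :, k] + sign * s * x[..., k + 1 :]) / c
        x[..., k + 1 :] = c * x[..., k + 1 :] - s * L[..., k + 1 :, k]
    return L
```

Each update costs O(matrix_dim²) per batch element, versus O(matrix_dim³) for refactorizing \(LL^\top + \text{weight} \cdot xx^\top\) from scratch.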

fannypack.utils.tril_count_from_matrix_dim(matrix_dim: int)

Computes the number of lower-triangular terms, including the diagonal, in a square matrix of shape (matrix_dim, matrix_dim).

Parameters

matrix_dim (int) – Dimension of square matrix.

Returns

int – Count of lower-triangular terms.
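Counting the diagonal plus the strictly-lower entries gives \(n(n+1)/2\) for an n-by-n matrix. A hypothetical one-line sketch, assuming the diagonal is included in the count:

```python
def tril_count_sketch(matrix_dim: int) -> int:
    """Number of entries on or below the diagonal of a (matrix_dim, matrix_dim) matrix."""
    # matrix_dim diagonal entries + matrix_dim * (matrix_dim - 1) / 2 strictly-lower entries
    return matrix_dim * (matrix_dim + 1) // 2
```

For a 3×3 matrix this gives 3 + 2 + 1 = 6 terms.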

fannypack.utils.matrix_dim_from_tril_count(tril_count: int)

Computes the dimension of a lower triangular square matrix given a count of its lower-triangular components.

Parameters

tril_count (int) – Count of lower-triangular terms.

Returns

int – Dimension of square matrix.
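Going the other way means solving \(n(n+1)/2 = c\) for n, which gives \(n = (\sqrt{8c + 1} - 1)/2\). The hypothetical matrix_dim_sketch below uses an exact integer square root and, as an added assumption, rejects counts that are not triangular numbers:

```python
import math


def matrix_dim_sketch(tril_count: int) -> int:
    """Invert tril_count = n * (n + 1) / 2 for the matrix dimension n."""
    # Quadratic formula: n = (sqrt(8 * tril_count + 1) - 1) / 2, using an exact integer sqrt
    discriminant = 8 * tril_count + 1
    root = math.isqrt(discriminant)
    if root * root != discriminant:
        raise ValueError(f"{tril_count} is not a triangular number")
    return (root - 1) // 2
```

Using math.isqrt (Python 3.8+) avoids floating-point rounding for large counts.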

fannypack.utils.tril_from_vector(lower_vector: torch.Tensor) → torch.Tensor

Computes lower-triangular square matrices from a flattened vector of nonzero terms. Supports arbitrary batch dimensions.

Parameters

lower_vector (torch.Tensor) – Vectors containing the nonzero terms of a square lower-triangular matrix. Shape should be (*, tril_count).

Returns

torch.Tensor – Square matrices. Shape should be (*, matrix_dim, matrix_dim).
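A sketch of the scatter step, using torch.tril_indices to enumerate the on-or-below-diagonal positions. The row-major ordering that torch.tril_indices produces is an assumption here; fannypack's own element ordering may differ, and tril_from_vector_sketch is a hypothetical name.

```python
import math

import torch


def tril_from_vector_sketch(lower_vector: torch.Tensor) -> torch.Tensor:
    """Scatter (*, tril_count) vectors into (*, matrix_dim, matrix_dim) lower-triangular matrices."""
    tril_count = lower_vector.shape[-1]
    matrix_dim = (math.isqrt(8 * tril_count + 1) - 1) // 2  # inverts n(n+1)/2
    # Row-major (row, col) indices of every entry on or below the diagonal
    rows, cols = torch.tril_indices(matrix_dim, matrix_dim, device=lower_vector.device)
    out = lower_vector.new_zeros(lower_vector.shape[:-1] + (matrix_dim, matrix_dim))
    out[..., rows, cols] = lower_vector
    return out
```

With this ordering, the vector [1, 2, 3, 4, 5, 6] fills rows as [1], [2, 3], [4, 5, 6].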

fannypack.utils.vector_from_tril(tril_matrix: torch.Tensor) → torch.Tensor

Retrieves the lower triangular terms of square matrices as vectors. Supports arbitrary batch dimensions.

Parameters

tril_matrix (torch.Tensor) – Square matrices. Shape should be (*, matrix_dim, matrix_dim).

Returns

torch.Tensor – Flattened vectors. Shape should be (*, tril_count).
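The corresponding gather can be sketched the same way. The row-major torch.tril_indices ordering and the name vector_from_tril_sketch are assumptions, not fannypack's confirmed layout:

```python
import torch


def vector_from_tril_sketch(tril_matrix: torch.Tensor) -> torch.Tensor:
    """Gather the on-or-below-diagonal entries of (*, N, N) matrices into (*, tril_count)."""
    matrix_dim = tril_matrix.shape[-1]
    rows, cols = torch.tril_indices(matrix_dim, matrix_dim, device=tril_matrix.device)
    # Advanced indexing on the last two axes keeps every leading batch dimension
    return tril_matrix[..., rows, cols]
```

Entries above the diagonal are simply ignored, so the input does not need to be zeroed there first.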

fannypack.utils.gaussian_log_prob(mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor) → torch.Tensor

Computes log probabilities under multivariate Gaussian distributions, with support for arbitrary batch axes.

Naive version of…

torch.distributions.MultivariateNormal(
    mean, covariance
).log_prob(value)

that avoids some Cholesky-related CUDA errors. https://discuss.pytorch.org/t/cuda-illegal-memory-access-when-using-batched-torch-cholesky/51624

Stolen from @alberthli/@wuphilipp.

Parameters
  • mean (torch.Tensor) – Mean vectors. Shape should be (*, D).

  • covariance (torch.Tensor) – Covariance matrices. Shape should be (*, D, D).

  • value (torch.Tensor) – State vectors. Shape should be (*, D).

Returns

torch.Tensor – Batched log probabilities. Shape should be (*).
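Written out, the log-density is \(\log \mathcal{N}(x; \mu, \Sigma) = -\tfrac{1}{2}\left[D \log 2\pi + \log\lvert\Sigma\rvert + (x - \mu)^\top \Sigma^{-1} (x - \mu)\right]\). The hypothetical sketch below evaluates it with an LU-based linear solve and log-determinant instead of a Cholesky factorization. It illustrates the formula and does not claim to mirror fannypack's internals.

```python
import math

import torch


def gaussian_log_prob_sketch(
    mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """log N(value; mean, covariance) over arbitrary batch axes, without a Cholesky factorization."""
    dim = mean.shape[-1]
    diff = (value - mean).unsqueeze(-1)  # (*, D, 1)
    # Solve covariance @ y = diff instead of forming the inverse explicitly
    solved = torch.linalg.solve(covariance, diff)  # (*, D, 1)
    mahalanobis = (diff * solved).sum(dim=(-2, -1))  # (*)
    _, log_abs_det = torch.linalg.slogdet(covariance)  # (*)
    return -0.5 * (dim * math.log(2 * math.pi) + log_abs_det + mahalanobis)
```

On well-conditioned inputs, the result should agree with torch.distributions.MultivariateNormal(mean, covariance_matrix=covariance).log_prob(value).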

fannypack.utils.quadratic_matmul(x: torch.Tensor, A: torch.Tensor) → torch.Tensor

Computes \(x^\top A x\), with support for arbitrary batch axes.

Stolen from @alberthli/@wuphilipp.

Parameters
  • x (torch.Tensor) – Vectors. Shape should be (*, D).

  • A (torch.Tensor) – Matrices. Shape should be (*, D, D).

Returns

torch.Tensor – Batched output of multiplication. Shape should be (*).
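One compact way to express the batched quadratic form is a single torch.einsum call. This is a hypothetical sketch, not necessarily the implementation used by quadratic_matmul():

```python
import torch


def quadratic_matmul_sketch(x: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Batched x^T A x for x of shape (*, D) and A of shape (*, D, D)."""
    # "...i,...ij,...j->..." sums x_i * A_ij * x_j over i and j for every batch index
    return torch.einsum("...i,...ij,...j->...", x, A, x)
```

For instance, x = [1, 2] with A = I gives 1 + 4 = 5.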

Module Freezing

fannypack.utils.freeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None

Freeze the weights of a PyTorch module by setting the requires_grad attributes of enclosed parameters to False.

Parameters
  • module (torch.nn.Module) – Module to freeze.

  • recurse (bool, optional) – If True, then recursively freezes children. Otherwise, only freezes immediate parameters.

fannypack.utils.unfreeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None

Unfreeze the weights of a PyTorch module, which needs to have been frozen with fannypack.utils.freeze_module(). Restores all original requires_grad values.

Parameters
  • module (torch.nn.Module) – Module to unfreeze.

  • recurse (bool, optional) – If True, then recursively unfreezes children. Otherwise, only unfreezes immediate parameters.
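Because unfreezing has to restore the original flags, the freeze and unfreeze steps are sketched together here. Assumptions: the function names and the private attribute used to stash each parameter's original requires_grad value are hypothetical, and fannypack may store this state differently.

```python
from typing import Dict

import torch

# Hypothetical attribute used by this sketch to remember the original flags
_SAVED_FLAGS_ATTR = "_sketch_saved_requires_grad"


def freeze_module_sketch(module: torch.nn.Module, recurse: bool = True) -> None:
    """Set requires_grad=False on parameters, remembering each original value."""
    saved: Dict[str, bool] = {}
    for name, param in module.named_parameters(recurse=recurse):
        saved[name] = param.requires_grad
        param.requires_grad_(False)
    setattr(module, _SAVED_FLAGS_ATTR, saved)


def unfreeze_module_sketch(module: torch.nn.Module, recurse: bool = True) -> None:
    """Restore the requires_grad values recorded by freeze_module_sketch()."""
    saved = getattr(module, _SAVED_FLAGS_ATTR, None)
    if saved is None:
        raise RuntimeError("Module was not frozen with freeze_module_sketch()")
    for name, param in module.named_parameters(recurse=recurse):
        if name in saved:
            param.requires_grad_(saved[name])
    delattr(module, _SAVED_FLAGS_ATTR)
```

Recording flags per parameter name means a parameter that was already frozen before freezing stays frozen after unfreezing.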

Deprecation Helpers

fannypack.utils.new_name_wrapper(old_name: str, new_name: str, function_or_class: fannypack.utils._deprecation.CallableType) → fannypack.utils._deprecation.CallableType

Creates a wrapper for a renamed function or class. Prints a warning the first time a function or class is called with the old name.

Parameters
  • old_name (str) – Old name of function or class. Printed in warning.

  • new_name (str) – New name of function or class. Printed in warning.

  • function_or_class (CallableType) – Function or class to wrap.

Returns

CallableType – Wrapped function/class.

fannypack.utils.deprecation_wrapper(message: str, function_or_class: fannypack.utils._deprecation.CallableType) → fannypack.utils._deprecation.CallableType

Creates a wrapper for a deprecated function or class. Prints a warning the first time a function or class is called.

Parameters
  • message (str) – Warning message.

  • function_or_class (CallableType) – Function or class to wrap.

Returns

CallableType – Wrapped function/class.
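A minimal sketch of both deprecation helpers using functools.wraps and the standard warnings module, with new_name_wrapper_sketch expressed in terms of deprecation_wrapper_sketch. Assumptions: the names, the warning text, and the use of warnings.warn (rather than a plain print) are illustrative only. Wrapping a class this way returns a function, so isinstance() checks against the old name will not work.

```python
import functools
import warnings
from typing import Any, Callable, TypeVar

CallableType = TypeVar("CallableType", bound=Callable[..., Any])


def deprecation_wrapper_sketch(message: str, function_or_class: CallableType) -> CallableType:
    """Wrap a callable so that its first call emits `message` as a warning."""
    warned = False

    @functools.wraps(function_or_class)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        nonlocal warned
        if not warned:
            # stacklevel=2 attributes the warning to the caller, not this wrapper
            warnings.warn(message, stacklevel=2)
            warned = True
        return function_or_class(*args, **kwargs)

    return wrapped  # type: ignore[return-value]


def new_name_wrapper_sketch(
    old_name: str, new_name: str, function_or_class: CallableType
) -> CallableType:
    """Keep `old_name` working while pointing users toward `new_name`."""
    return deprecation_wrapper_sketch(
        f"{old_name} is deprecated! Use {new_name} instead.", function_or_class
    )
```

A typical use is a module-level alias such as old_add = new_name_wrapper_sketch("old_add", "add", add), which keeps existing callers working while nudging them toward the new name.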

Debugging

fannypack.utils.pdb_safety_net()

Attaches a “safety net” for unexpected errors in a Python script.

Once this is called, PDB opens automatically when either (a) the user hits Ctrl+C or (b) an uncaught exception is raised. Helpful for bypassing minor errors, diagnosing problems, and rescuing unsaved models.
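A minimal sketch of the idea using only the standard library: replace sys.excepthook with a post-mortem debugger, and install a SIGINT handler that drops into PDB at the interrupted frame. The name pdb_safety_net_sketch is hypothetical, and fannypack's actual implementation may differ.

```python
import pdb
import signal
import sys
import traceback
from types import FrameType, TracebackType
from typing import Optional, Type


def pdb_safety_net_sketch() -> None:
    """Open PDB on uncaught exceptions and on Ctrl+C (SIGINT)."""

    def on_uncaught_exception(
        exc_type: Type[BaseException],
        exc_value: BaseException,
        exc_traceback: Optional[TracebackType],
    ) -> None:
        # Show the usual traceback, then inspect the frame where the error was raised
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        pdb.post_mortem(exc_traceback)

    def on_sigint(signum: int, frame: Optional[FrameType]) -> None:
        # Pause at whatever line was executing when Ctrl+C was pressed
        pdb.Pdb().set_trace(frame)

    sys.excepthook = on_uncaught_exception
    # signal.signal() must be called from the main thread
    signal.signal(signal.SIGINT, on_sigint)
```

Calling this once near the top of a training script is enough. From the Ctrl+C prompt, typing c (continue) resumes execution.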

fannypack.utils.get_git_commit_hash(path: str = './') → str

Returns the current Git commit hash.

Parameters

path (str, optional) – Path to check. Defaults to ‘./’.
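A minimal sketch, assuming the git command-line tool is installed and on PATH: running git rev-parse HEAD inside path prints the full hash of the checked-out commit. The name get_git_commit_hash_sketch is hypothetical.

```python
import subprocess


def get_git_commit_hash_sketch(path: str = "./") -> str:
    """Return the hash of the commit currently checked out in the repository at `path`."""
    # Raises subprocess.CalledProcessError if `path` is not inside a Git repository
    output = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=path)
    return output.decode("utf-8").strip()
```

Logging this hash alongside experiment results makes it easier to tie a saved model back to the exact code that produced it.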