Potpourri
Miscellaneous utilities.
Math
- fannypack.utils.cholesky_inverse(u: torch.Tensor, upper: bool = False) → torch.Tensor
Alternative to torch.cholesky_inverse(), with support for batch dimensions.
Relevant issue tracker: https://github.com/pytorch/pytorch/issues/7500
- Parameters
u (torch.Tensor) – Triangular Cholesky factor. Shape should be (*, N, N).
upper (bool, optional) – Whether to treat the Cholesky factor as an upper triangular matrix (True) or a lower triangular one (False).
- Returns
torch.Tensor
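For orientation, the identity behind a Cholesky-based inverse is written out below. It follows the convention of torch.cholesky_inverse(); that this function returns the same quantity per batch element is an assumption based on its description as an alternative to that call.

```latex
\begin{align*}
\texttt{upper=False:}\quad & A = L L^\top \;\Rightarrow\; A^{-1} = L^{-\top} L^{-1} \\
\texttt{upper=True:}\quad  & A = U^\top U \;\Rightarrow\; A^{-1} = U^{-1} U^{-\top}
\end{align*}
```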
- fannypack.utils.tril_inverse(tril_matrix: torch.Tensor) → torch.Tensor
Invert a lower-triangular matrix.
- Parameters
tril_matrix (torch.Tensor) – Lower-triangular matrix to invert. Shape should be (*, matrix_dim, matrix_dim).
- Returns
torch.Tensor – Inverted matrix. Shape should be (*, matrix_dim, matrix_dim).
- fannypack.utils.cholupdate(L: torch.Tensor, x: torch.Tensor, weight: Optional[Union[torch.Tensor, float]] = None) → torch.Tensor
Batched rank-1 Cholesky update.
Computes the Cholesky decomposition of LL^T + weight * xx^T.
- Parameters
L (torch.Tensor) – Lower triangular Cholesky decomposition of a PSD matrix. Shape should be (*, matrix_dim, matrix_dim).
x (torch.Tensor) – Rank-1 update vector. Shape should be (*, matrix_dim).
weight (torch.Tensor or float, optional) – Weight applied to the update term xx^T. Set to -1 for “downdate”. Shape must be broadcastable with (*, matrix_dim).
- Returns
torch.Tensor – New L matrix. Same shape as L.
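To make the rank-1 update concrete, here is a minimal, unbatched sketch of the classic algorithm in plain Python. It is not fannypack's implementation: the name cholupdate_sketch is hypothetical, matrices are lists of rows, and a weight default of 1.0 is assumed for illustration.

```python
import math


def cholupdate_sketch(L, x, weight=1.0):
    """Return the lower Cholesky factor of L L^T + weight * x x^T.

    Hypothetical helper: L is a list of rows (lower triangular, positive
    diagonal), x is a list of floats. Inputs are copied, not modified.
    """
    n = len(x)
    L = [list(row) for row in L]
    # A negative weight is handled as a "downdate" of the scaled vector.
    sign = 1.0 if weight >= 0 else -1.0
    scale = math.sqrt(abs(weight))
    x = [scale * value for value in x]
    for k in range(n):
        # New diagonal entry; for a downdate the radicand must stay positive,
        # otherwise math.sqrt raises ValueError.
        r = math.sqrt(L[k][k] ** 2 + sign * x[k] ** 2)
        c = r / L[k][k]
        s = x[k] / L[k][k]
        L[k][k] = r
        for i in range(k + 1, n):
            # Update column k, then fold the change into the remaining vector.
            L[i][k] = (L[i][k] + sign * s * x[i]) / c
            x[i] = c * x[i] - s * L[i][k]
    return L
```

Calling the sketch with weight=-1.0 on an updated factor and the same vector recovers the original factor, which is a convenient sanity check.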
- fannypack.utils.tril_count_from_matrix_dim(matrix_dim: int)
Computes the number of lower triangular terms in a square matrix of a given dimension (matrix_dim, matrix_dim).
- Parameters
matrix_dim (int) – Dimension of square matrix.
- Returns
int – Count of lower-triangular terms.
- fannypack.utils.matrix_dim_from_tril_count(tril_count: int)
Computes the dimension of a lower triangular square matrix given a count of its lower-triangular components.
- Parameters
tril_count (int) – Count of lower-triangular terms.
- Returns
int – Dimension of square matrix.
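The two conversions above are inverses of each other: an N x N matrix has N(N+1)/2 lower-triangular terms, diagonal included. The sketch below shows that arithmetic in plain Python. The function names are hypothetical, and raising ValueError for counts that match no square matrix is an assumption, not documented fannypack behavior.

```python
import math


def tril_count_sketch(matrix_dim):
    """Number of lower-triangular terms (diagonal included) in an N x N matrix."""
    return matrix_dim * (matrix_dim + 1) // 2


def matrix_dim_sketch(tril_count):
    """Invert tril_count = N * (N + 1) / 2 for N using integer arithmetic."""
    matrix_dim = (math.isqrt(8 * tril_count + 1) - 1) // 2
    # Assumed behavior: reject counts that no square matrix produces.
    if tril_count_sketch(matrix_dim) != tril_count:
        raise ValueError(f"{tril_count} is not a valid lower-triangular count")
    return matrix_dim
```

For example, a 3 x 3 matrix has 6 lower-triangular terms, and a count of 6 maps back to a dimension of 3.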
- fannypack.utils.tril_from_vector(lower_vector: torch.Tensor) → torch.Tensor
Computes lower-triangular square matrices from a flattened vector of nonzero terms. Supports arbitrary batch dimensions.
- Parameters
lower_vector (torch.Tensor) – Vectors containing the nonzero terms of a square lower-triangular matrix. Shape should be (*, tril_count).
- Returns
torch.Tensor – Square matrices. Shape should be (*, matrix_dim, matrix_dim).
- fannypack.utils.vector_from_tril(tril_matrix: torch.Tensor) → torch.Tensor
Retrieves the lower triangular terms of square matrices as vectors. Supports arbitrary batch dimensions.
- Parameters
tril_matrix (torch.Tensor) – Square matrices. Shape should be (*, matrix_dim, matrix_dim).
- Returns
torch.Tensor – Flattened vectors. Shape should be (*, tril_count).
- fannypack.utils.gaussian_log_prob(mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor) → torch.Tensor
Computes log probabilities under multivariate Gaussian distributions, with support for arbitrary batch axes.
Naive version of torch.distributions.MultivariateNormal(mean, covariance).log_prob(value) that avoids some Cholesky-related CUDA errors. See https://discuss.pytorch.org/t/cuda-illegal-memory-access-when-using-batched-torch-cholesky/51624
Stolen from @alberthli/@wuphilipp.
- Parameters
mean (torch.Tensor) – Mean vectors. Shape should be (*, D).
covariance (torch.Tensor) – Covariance matrices. Shape should be (*, D, D).
value (torch.Tensor) – State vectors. Shape should be (*, D).
- Returns
torch.Tensor – Batched log probabilities. Shape should be (*).
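For reference, the quantity being computed is the standard multivariate Gaussian log density, written here with x for value, μ for mean, and Σ for covariance. How the function evaluates it numerically (for example, whether it inverts Σ directly) is not stated above, so the formula describes the result only.

```latex
\[
\log \mathcal{N}(x;\, \mu, \Sigma)
  = -\tfrac{1}{2}\left[ (x - \mu)^\top \Sigma^{-1} (x - \mu)
  + \log\det\Sigma + D \log(2\pi) \right]
\]
```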
- fannypack.utils.quadratic_matmul(x: torch.Tensor, A: torch.Tensor) → torch.Tensor
Computes \(x^\top A x\), with support for arbitrary batch axes.
Stolen from @alberthli/@wuphilipp.
- Parameters
x (torch.Tensor) – Vectors. Shape should be (*, D).
A (torch.Tensor) – Matrices. Shape should be (*, D, D).
- Returns
torch.Tensor – Batched output of multiplication. Shape should be (*).
Module Freezing
- fannypack.utils.freeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None
Freeze the weights of a PyTorch module by setting the requires_grad attributes of enclosed parameters to False.
- Parameters
module (torch.nn.Module) – Module to freeze.
recurse (bool, optional) – If True, then recursively freezes children. Otherwise, only freezes immediate parameters.
- fannypack.utils.unfreeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None
Unfreeze the weights of a PyTorch module, which needs to have been frozen with fannypack.utils.freeze_module(). Restores all original requires_grad values.
- Parameters
module (torch.nn.Module) – Module to unfreeze.
recurse (bool, optional) – If True, then recursively unfreezes children. Otherwise, only unfreezes immediate parameters.
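The freeze and unfreeze pairing can be sketched as "remember the original flags, then restore them". The code below is an illustration under stated assumptions, not fannypack's implementation: the function names and the _original_flags bookkeeping are hypothetical. It relies only on module.parameters(recurse=...) and each parameter's requires_grad attribute, which torch.nn.Module also provides, so small stand-in classes are used to keep the example runnable without PyTorch.

```python
import weakref

# Hypothetical bookkeeping: original requires_grad flags, keyed by module.
_original_flags = weakref.WeakKeyDictionary()


def freeze_sketch(module, recurse=True):
    """Set requires_grad to False on the parameters, remembering old values."""
    params = list(module.parameters(recurse=recurse))
    _original_flags[module] = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad = False


def unfreeze_sketch(module, recurse=True):
    """Restore the requires_grad values recorded by freeze_sketch()."""
    if module not in _original_flags:
        raise RuntimeError("module was not frozen with freeze_sketch()")
    flags = _original_flags.pop(module)
    for p, flag in zip(module.parameters(recurse=recurse), flags):
        p.requires_grad = flag


# Stand-ins that mimic the small part of the torch.nn.Module interface used above.
class FakeParam:
    def __init__(self, requires_grad):
        self.requires_grad = requires_grad


class FakeModule:
    def __init__(self, params):
        self._params = params

    def parameters(self, recurse=True):
        return iter(self._params)
```

Restoring recorded values, rather than setting everything back to True, keeps parameters that were already non-trainable before freezing in that state.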
Deprecation Helpers
- fannypack.utils.new_name_wrapper(old_name: str, new_name: str, function_or_class: fannypack.utils._deprecation.CallableType) → fannypack.utils._deprecation.CallableType
Creates a wrapper for a renamed function or class. Prints a warning the first time a function or class is called with the old name.
- Parameters
old_name (str) – Old name of function or class. Printed in warning.
new_name (str) – New name of function or class. Printed in warning.
function_or_class (CallableType) – Function or class to wrap.
- Returns
CallableType – Wrapped function/class.
- fannypack.utils.deprecation_wrapper(message: str, function_or_class: fannypack.utils._deprecation.CallableType) → fannypack.utils._deprecation.CallableType
Creates a wrapper for a deprecated function or class. Prints a warning the first time a function or class is called.
- Parameters
message (str) – Warning message.
function_or_class (CallableType) – Function or class to wrap.
- Returns
CallableType – Wrapped function/class.
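A warn-once wrapper of the kind described above can be sketched with the standard library. This is an illustration, not fannypack's code: the function names are hypothetical, the message format is invented, and using warnings.warn with DeprecationWarning is an assumption (the helpers above are only documented as printing a warning).

```python
import functools
import warnings


def deprecation_wrapper_sketch(message, function_or_class):
    """Wrap a callable so that `message` is emitted on the first call only."""
    warned = False

    @functools.wraps(function_or_class)
    def wrapped(*args, **kwargs):
        nonlocal warned
        if not warned:
            # stacklevel=2 attributes the warning to the caller's line.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            warned = True
        return function_or_class(*args, **kwargs)

    return wrapped


def new_name_wrapper_sketch(old_name, new_name, function_or_class):
    """Warn-once wrapper for a renamed callable (message format is made up)."""
    message = f"{old_name} is deprecated; use {new_name} instead."
    return deprecation_wrapper_sketch(message, function_or_class)
```

A typical use is to keep an old name importable after a rename, for example `old_add = new_name_wrapper_sketch("old_add", "add", add)`.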
Debugging
- fannypack.utils.pdb_safety_net()
Attaches a “safety net” for unexpected errors in a Python script.
After this function is called, PDB opens automatically when either (a) the user hits Ctrl+C or (b) an uncaught exception is raised. Helpful for bypassing minor errors, diagnosing problems, and rescuing unsaved models.
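One way such a safety net can be built is with sys.excepthook and a SIGINT handler. The sketch below is a guess at the general mechanism rather than fannypack's implementation, and the name pdb_safety_net_sketch is hypothetical.

```python
import pdb
import signal
import sys
import traceback


def pdb_safety_net_sketch():
    """Open PDB on uncaught exceptions and on Ctrl+C (illustrative only)."""

    def on_exception(exc_type, exc_value, exc_traceback):
        # Show the error as usual, then inspect the frame where it was raised.
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        pdb.post_mortem(exc_traceback)

    def on_sigint(signum, frame):
        # Drop into the debugger at the frame that was interrupted.
        pdb.Pdb().set_trace(frame)

    sys.excepthook = on_exception
    # signal.signal() may only be called from the main thread.
    signal.signal(signal.SIGINT, on_sigint)
```

Calling the function once near the top of a training script is enough; both hooks stay installed for the rest of the process.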
- fannypack.utils.get_git_commit_hash(path: str = './') → str
Returns the current Git commit hash.
- Parameters
path (str, optional) – Path to check. Defaults to ‘./’.
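A commit hash lookup like this can be approximated by shelling out to Git. The sketch below assumes the git executable is on PATH and that path lies inside a repository; the function name is hypothetical, and fannypack may obtain the hash differently.

```python
import subprocess


def get_git_commit_hash_sketch(path="./"):
    """Return the hash of HEAD for the repository containing `path`."""
    result = subprocess.run(
        ["git", "rev-parse", "HEAD"],
        cwd=path,
        capture_output=True,
        text=True,
        check=True,  # raises CalledProcessError outside a repository
    )
    return result.stdout.strip()
```

Logging the returned hash alongside experiment results makes it possible to trace a run back to the exact code that produced it.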