Potpourri
Miscellaneous utilities.
Math
- fannypack.utils.cholesky_inverse(u: torch.Tensor, upper: bool = False) → torch.Tensor
Alternative to torch.cholesky_inverse(), with support for batch dimensions. Relevant issue tracker: https://github.com/pytorch/pytorch/issues/7500
- Parameters
u (torch.Tensor) – Triangular Cholesky factor. Shape should be (*, N, N).
upper (bool, optional) – If True, treat u as an upper-triangular factor; otherwise (the default), treat it as lower-triangular.
- Returns
torch.Tensor – Inverse of the matrix whose Cholesky factor is u. Shape should be (*, N, N).
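For illustration, a minimal sketch of a batched Cholesky inverse in recent PyTorch, assuming u factors a matrix A as \(A = uu^\top\) (or \(A = u^\top u\) when upper=True). Solving \(AX = I\) with torch.cholesky_solve, which accepts batch axes, gives \(X = A^{-1}\). The name cholesky_inverse_sketch is hypothetical; this is not necessarily how fannypack implements the function.

```python
import torch


def cholesky_inverse_sketch(u: torch.Tensor, upper: bool = False) -> torch.Tensor:
    """Hypothetical helper: invert A from its Cholesky factor u, shape (*, N, N).

    A = u @ u^T when upper=False, and A = u^T @ u when upper=True.
    """
    n = u.shape[-1]
    # Identity right-hand side, expanded to the batch shape of u.
    eye = torch.eye(n, dtype=u.dtype, device=u.device).expand(u.shape).contiguous()
    # Solving A X = I gives X = A^{-1}; cholesky_solve handles leading batch axes.
    return torch.cholesky_solve(eye, u, upper=upper)
```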
- fannypack.utils.tril_inverse(tril_matrix: torch.Tensor) → torch.Tensor
Invert a lower-triangular matrix.
- Parameters
tril_matrix (torch.Tensor) – Lower-triangular matrix to invert. Shape should be (*, matrix_dim, matrix_dim).
- Returns
torch.Tensor – Inverted matrix. Shape should be (*, matrix_dim, matrix_dim).
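A minimal sketch of the same operation, assuming PyTorch 1.11 or newer for torch.linalg.solve_triangular: inverting a lower-triangular T amounts to solving \(TX = I\) by forward substitution, batched over the leading axes. tril_inverse_sketch is a hypothetical name, not fannypack's implementation.

```python
import torch


def tril_inverse_sketch(tril_matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: invert lower-triangular matrices of shape (*, n, n)."""
    n = tril_matrix.shape[-1]
    eye = torch.eye(n, dtype=tril_matrix.dtype, device=tril_matrix.device)
    eye = eye.expand(tril_matrix.shape).contiguous()
    # Triangular solve T X = I (forward substitution); X = T^{-1} is also lower-triangular.
    return torch.linalg.solve_triangular(tril_matrix, eye, upper=False)
```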
- fannypack.utils.cholupdate(L: torch.Tensor, x: torch.Tensor, weight: Optional[Union[torch.Tensor, float]] = None) → torch.Tensor
Batched rank-1 Cholesky update.
Computes the Cholesky decomposition of \(LL^\top + \text{weight} \cdot xx^\top\).
- Parameters
L (torch.Tensor) – Lower-triangular Cholesky decomposition of a PSD matrix. Shape should be (*, matrix_dim, matrix_dim).
x (torch.Tensor) – Rank-1 update vector. Shape should be (*, matrix_dim).
weight (torch.Tensor or float, optional) – Set to -1 for a “downdate”. Shape must be broadcastable with (*, matrix_dim).
- Returns
torch.Tensor – New L matrix. Same shape as L.
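One way to compute this update is the standard column-by-column rank-1 recurrence, which costs \(O(n^2)\) per matrix instead of the \(O(n^3)\) of refactorizing from scratch. The sketch below batches it over leading axes. It assumes that a tensor weight has shape (*) (one scalar per matrix) and uses the sign of weight to switch between update and downdate; the name cholupdate_sketch is hypothetical and fannypack's own implementation may differ.

```python
from typing import Union

import torch


def cholupdate_sketch(
    L: torch.Tensor, x: torch.Tensor, weight: Union[torch.Tensor, float] = 1.0
) -> torch.Tensor:
    """Hypothetical helper: Cholesky factor of L L^T + weight * x x^T.

    L has shape (*, n, n); x has shape (*, n); weight is a float or a tensor of shape (*).
    """
    L = L.clone()
    w = torch.as_tensor(weight, dtype=L.dtype, device=L.device)
    sign = torch.sign(w)
    # Fold |weight| into x (this creates a new tensor), so the loop only needs the sign.
    x = x * torch.sqrt(torch.abs(w)).unsqueeze(-1)
    n = L.shape[-1]
    for k in range(n):
        L_kk = L[..., k, k].clone()
        r = torch.sqrt(L_kk**2 + sign * x[..., k] ** 2)
        c = r / L_kk
        s = x[..., k] / L_kk
        L[..., k, k] = r
        if k + 1 < n:
            # Rotate column k below the diagonal, then update the remaining entries of x.
            col = (L[..., k + 1 :, k] + (sign * s).unsqueeze(-1) * x[..., k + 1 :]) / c.unsqueeze(-1)
            L[..., k + 1 :, k] = col
            x[..., k + 1 :] = c.unsqueeze(-1) * x[..., k + 1 :] - s.unsqueeze(-1) * col
    return L
```

A downdate only succeeds if \(LL^\top + \text{weight} \cdot xx^\top\) remains positive definite; otherwise the square root produces NaN.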
- fannypack.utils.tril_count_from_matrix_dim(matrix_dim: int)
Computes the number of lower-triangular terms in a square matrix of a given dimension (matrix_dim, matrix_dim).
- Parameters
matrix_dim (int) – Dimension of square matrix.
- Returns
int – Count of lower-triangular terms.
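Counting row by row, row i (0-indexed) of the lower triangle holds i + 1 entries, so an n × n matrix has \(n(n+1)/2\) lower-triangular terms; a 3 × 3 matrix has 6. A hypothetical pure-Python sketch, assuming the diagonal counts as part of the lower triangle (as it does for torch.tril):

```python
def tril_count_from_matrix_dim_sketch(matrix_dim: int) -> int:
    """Hypothetical helper: number of lower-triangular terms, diagonal included."""
    # Row i (0-indexed) contributes i + 1 terms, so the total is 1 + 2 + ... + matrix_dim.
    return matrix_dim * (matrix_dim + 1) // 2
```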
- fannypack.utils.matrix_dim_from_tril_count(tril_count: int)
Computes the dimension of a lower triangular square matrix given a count of its lower-triangular components.
- Parameters
tril_count (int) – Count of lower-triangular terms.
- Returns
int – Dimension of square matrix.
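Inverting \(t = n(n+1)/2\) gives \(n = (\sqrt{8t + 1} - 1)/2\). The hypothetical sketch below uses an exact integer square root and rejects counts that are not triangular numbers; whether fannypack validates its input this way is an assumption.

```python
import math


def matrix_dim_from_tril_count_sketch(tril_count: int) -> int:
    """Hypothetical helper: recover n from n * (n + 1) / 2 == tril_count."""
    discriminant = 8 * tril_count + 1
    root = math.isqrt(discriminant)  # exact integer square root (Python 3.8+)
    if root * root != discriminant:
        raise ValueError(f"{tril_count} is not a triangular number")
    # discriminant is odd, so root is odd and root - 1 is even.
    return (root - 1) // 2
```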
- fannypack.utils.tril_from_vector(lower_vector: torch.Tensor) → torch.Tensor
Computes lower-triangular square matrices from a flattened vector of nonzero terms. Supports arbitrary batch dimensions.
- Parameters
lower_vector (torch.Tensor) – Vectors containing the nonzero terms of a square lower-triangular matrix. Shape should be (*, tril_count).
- Returns
torch.Tensor – Square matrices. Shape should be (*, matrix_dim, matrix_dim).
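A minimal sketch using torch.tril_indices to scatter each vector into the lower triangle, batched over leading axes. It assumes the vector lists the lower-triangular terms in row-major order ((0, 0), (1, 0), (1, 1), (2, 0), and so on), which is the order torch.tril_indices produces; fannypack's actual ordering is not stated here, and tril_from_vector_sketch is a hypothetical name.

```python
import math

import torch


def tril_from_vector_sketch(lower_vector: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (*, tril_count) vectors -> (*, n, n) lower-triangular matrices."""
    tril_count = lower_vector.shape[-1]
    matrix_dim = (math.isqrt(8 * tril_count + 1) - 1) // 2
    if matrix_dim * (matrix_dim + 1) // 2 != tril_count:
        raise ValueError(f"{tril_count} is not a triangular number")
    # Row and column indices of the lower triangle, in row-major order.
    rows, cols = torch.tril_indices(matrix_dim, matrix_dim, device=lower_vector.device)
    out = lower_vector.new_zeros(tuple(lower_vector.shape[:-1]) + (matrix_dim, matrix_dim))
    out[..., rows, cols] = lower_vector
    return out
```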
- fannypack.utils.vector_from_tril(tril_matrix: torch.Tensor) → torch.Tensor
Retrieves the lower triangular terms of square matrices as vectors. Supports arbitrary batch dimensions.
- Parameters
tril_matrix (torch.Tensor) – Square matrices. Shape should be (*, matrix_dim, matrix_dim).
- Returns
torch.Tensor – Flattened vectors. Shape should be (*, tril_count).
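This direction can be sketched by gathering with torch.tril_indices, assuming row-major ordering of the lower triangle; entries above the diagonal are simply ignored. vector_from_tril_sketch is a hypothetical name, not fannypack's implementation.

```python
import torch


def vector_from_tril_sketch(tril_matrix: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (*, n, n) matrices -> (*, n * (n + 1) / 2) vectors."""
    matrix_dim = tril_matrix.shape[-1]
    # Row and column indices of the lower triangle, in row-major order.
    rows, cols = torch.tril_indices(matrix_dim, matrix_dim, device=tril_matrix.device)
    return tril_matrix[..., rows, cols]
```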
- fannypack.utils.gaussian_log_prob(mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor) → torch.Tensor
Computes log probabilities under multivariate Gaussian distributions, with support for arbitrary batch axes.
Naive version of torch.distributions.MultivariateNormal(mean, covariance).log_prob(value) that avoids some Cholesky-related CUDA errors: https://discuss.pytorch.org/t/cuda-illegal-memory-access-when-using-batched-torch-cholesky/51624
Stolen from @alberthli/@wuphilipp.
- Parameters
mean (torch.Tensor) – Mean vectors. Shape should be (*, D).
covariance (torch.Tensor) – Covariance matrices. Shape should be (*, D, D).
value (torch.Tensor) – State vectors. Shape should be (*, D).
- Returns
torch.Tensor – Batched log probabilities. Shape should be (*).
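The naive computation follows directly from the multivariate Gaussian density: \(\log p(v) = -\tfrac{1}{2}\left(D \log 2\pi + \log\det\Sigma + (v - \mu)^\top \Sigma^{-1} (v - \mu)\right)\). A minimal sketch, assuming a batched linear solve (torch.linalg.solve) and torch.logdet in place of a Cholesky factorization; the name gaussian_log_prob_sketch is hypothetical and fannypack's version may compute these terms differently.

```python
import math

import torch


def gaussian_log_prob_sketch(
    mean: torch.Tensor, covariance: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """Hypothetical helper: log N(value; mean, covariance) with arbitrary batch axes."""
    D = mean.shape[-1]
    diff = value - mean
    # Mahalanobis term (v - mu)^T Sigma^{-1} (v - mu), via a batched linear solve.
    solved = torch.linalg.solve(covariance, diff.unsqueeze(-1)).squeeze(-1)
    mahalanobis = (diff * solved).sum(dim=-1)
    return -0.5 * (D * math.log(2 * math.pi) + torch.logdet(covariance) + mahalanobis)
```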
- fannypack.utils.quadratic_matmul(x: torch.Tensor, A: torch.Tensor) → torch.Tensor
Computes \(x^\top A x\), with support for arbitrary batch axes.
Stolen from @alberthli/@wuphilipp.
- Parameters
x (torch.Tensor) – Vectors. Shape should be (*, D).
A (torch.Tensor) – Matrices. Shape should be (*, D, D).
- Returns
torch.Tensor – Batched output of multiplication. Shape should be (*).
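A one-line sketch with torch.einsum, which sums \(x_i A_{ij} x_j\) over both indices while keeping the leading batch axes; quadratic_matmul_sketch is a hypothetical name. For example, with x = [1, 2] and A = [[2, 1], [0, 3]] the result is 16.

```python
import torch


def quadratic_matmul_sketch(x: torch.Tensor, A: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: x^T A x for x of shape (*, D) and A of shape (*, D, D)."""
    # Contract over i and j; "..." carries the batch axes through to the output.
    return torch.einsum("...i,...ij,...j->...", x, A, x)
```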
Module Freezing
- fannypack.utils.freeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None
Freeze the weights of a PyTorch module by setting the requires_grad attributes of enclosed parameters to False.
- Parameters
module (torch.nn.Module) – Module to freeze.
recurse (bool, optional) – If True, then recursively freezes children. Otherwise, only freezes immediate parameters.
- fannypack.utils.unfreeze_module(module: torch.nn.modules.module.Module, recurse: bool = True) → None
Unfreeze the weights of a PyTorch module, which must have been frozen with fannypack.utils.freeze_module(). Restores all original requires_grad values.
- Parameters
module (torch.nn.Module) – Module to unfreeze.
recurse (bool, optional) – If True, then recursively unfreezes children. Otherwise, only unfreezes immediate parameters.
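Restoring the original values requires remembering them at freeze time. A minimal sketch, assuming the saved flags are stored on the module in a dictionary keyed by parameter name; the attribute name _hypothetical_saved_requires_grad and both function names are hypothetical, not fannypack's.

```python
import torch.nn as nn

# Hypothetical attribute used to remember requires_grad flags between calls.
_SAVED_ATTR = "_hypothetical_saved_requires_grad"


def freeze_module_sketch(module: nn.Module, recurse: bool = True) -> None:
    """Record each parameter's requires_grad flag, then set it to False."""
    saved = {}
    for name, param in module.named_parameters(recurse=recurse):
        saved[name] = param.requires_grad
        param.requires_grad_(False)
    setattr(module, _SAVED_ATTR, saved)


def unfreeze_module_sketch(module: nn.Module, recurse: bool = True) -> None:
    """Restore the flags recorded by freeze_module_sketch()."""
    saved = getattr(module, _SAVED_ATTR, None)
    if saved is None:
        raise RuntimeError("module was not frozen with freeze_module_sketch()")
    for name, param in module.named_parameters(recurse=recurse):
        if name in saved:
            param.requires_grad_(saved[name])
    delattr(module, _SAVED_ATTR)
```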
Deprecation Helpers
- fannypack.utils.new_name_wrapper(old_name: str, new_name: str, function_or_class: fannypack.utils._deprecation.CallableType) → fannypack.utils._deprecation.CallableType
Creates a wrapper for a renamed function or class. Prints a warning the first time a function or class is called with the old name.
- Parameters
old_name (str) – Old name of function or class. Printed in warning.
new_name (str) – New name of function or class. Printed in warning.
function_or_class (CallableType) – Function or class to wrap.
- Returns
CallableType – Wrapped function/class.
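A minimal sketch of a warn-once rename wrapper for plain functions, assuming Python's warnings module with DeprecationWarning as the warning channel. Wrapping a class so that isinstance() checks and subclassing keep working needs more care and is not covered; new_name_wrapper_sketch is a hypothetical name.

```python
import functools
import warnings
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def new_name_wrapper_sketch(old_name: str, new_name: str, function: F) -> F:
    """Hypothetical helper: warn once when `function` is called through its old name."""
    warned = False

    @functools.wraps(function)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        nonlocal warned
        if not warned:
            # stacklevel=2 attributes the warning to the caller, not this wrapper.
            warnings.warn(
                f"{old_name} has been renamed to {new_name}",
                DeprecationWarning,
                stacklevel=2,
            )
            warned = True
        return function(*args, **kwargs)

    return wrapped  # type: ignore[return-value]
```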
- fannypack.utils.deprecation_wrapper(message: str, function_or_class: fannypack.utils._deprecation.CallableType) → fannypack.utils._deprecation.CallableType
Creates a wrapper for a deprecated function or class. Prints a warning the first time a function or class is called.
- Parameters
message (str) – Warning message.
function_or_class (CallableType) – Function or class to wrap.
- Returns
CallableType – Wrapped function/class.
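A minimal sketch of a warn-once deprecation wrapper, assuming the warning is printed to stderr (the channel fannypack actually uses is an assumption). Only plain functions are handled, and deprecation_wrapper_sketch is a hypothetical name.

```python
import functools
import sys
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def deprecation_wrapper_sketch(message: str, function: F) -> F:
    """Hypothetical helper: print `message` to stderr the first time `function` is called."""
    state = {"warned": False}

    @functools.wraps(function)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if not state["warned"]:
            # sys.stderr is looked up at call time, so redirected streams are respected.
            print(f"DeprecationWarning: {message}", file=sys.stderr)
            state["warned"] = True
        return function(*args, **kwargs)

    return wrapped  # type: ignore[return-value]
```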
Debugging
- fannypack.utils.pdb_safety_net()
Attaches a “safety net” for unexpected errors in a Python script.
When called, PDB will be automatically opened when either (a) the user hits Ctrl+C or (b) we encounter an uncaught exception. Helpful for bypassing minor errors, diagnosing problems, and rescuing unsaved models.
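The two triggers map onto standard-library hooks: sys.excepthook for uncaught exceptions and a SIGINT handler for Ctrl+C. A minimal sketch under that assumption; pdb_safety_net_sketch is a hypothetical name, and fannypack may install its hooks differently.

```python
import pdb
import signal
import sys
import traceback
from types import FrameType, TracebackType
from typing import Optional, Type


def pdb_safety_net_sketch() -> None:
    """Hypothetical helper: open PDB on Ctrl+C or on an uncaught exception."""

    def on_uncaught(
        exc_type: Type[BaseException], exc: BaseException, tb: Optional[TracebackType]
    ) -> None:
        # Print the usual traceback, then inspect the failing frame post mortem.
        traceback.print_exception(exc_type, exc, tb)
        pdb.post_mortem(tb)

    def on_sigint(signum: int, frame: Optional[FrameType]) -> None:
        # Pause wherever the program currently is instead of raising KeyboardInterrupt.
        pdb.Pdb().set_trace(frame)

    sys.excepthook = on_uncaught
    # signal.signal() must be called from the main thread.
    signal.signal(signal.SIGINT, on_sigint)
```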
- fannypack.utils.get_git_commit_hash(path: str = './') → str
Returns the current Git commit hash.
- Parameters
path (str, optional) – Path to check. Defaults to ‘./’.
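A minimal sketch that shells out to git rev-parse HEAD via subprocess, assuming the git executable is on PATH; fannypack may use a different mechanism, and get_git_commit_hash_sketch is a hypothetical name.

```python
import subprocess


def get_git_commit_hash_sketch(path: str = "./") -> str:
    """Hypothetical helper: full hash of the commit checked out in the repository at `path`."""
    # `git rev-parse HEAD` prints the hash of HEAD; check=True raises if `path` is not a repo.
    result = subprocess.run(
        ["git", "rev-parse", "HEAD"],
        cwd=path,
        check=True,
        capture_output=True,
        text=True,
    )
    return result.stdout.strip()
```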