jaxlie.manifold._deltas
Helpers for recursively applying tangent-space deltas.
Module Contents
Functions
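- rplus(transform, delta): Manifold right plus. Computes T' = T @ exp(delta).
- rminus(a, b): Manifold right minus. Computes delta = T_ab.log() = (T_wa.inverse() @ T_wb).log().
- rplus_jacobian_parameters_wrt_delta(transform): Analytical Jacobians for jaxlie.manifold.rplus(), linearized around a zero local delta.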
Attributes
- jaxlie.manifold._deltas.PytreeType
- jaxlie.manifold._deltas.GroupType
- jaxlie.manifold._deltas.rplus(transform: GroupType, delta: jaxlie.hints.Array) -> GroupType
- jaxlie.manifold._deltas.rplus(transform: PytreeType, delta: jaxlie.manifold._tree_utils.TangentPytree) -> PytreeType
Manifold right plus. Computes T' = T @ exp(delta). Supports pytrees containing Lie group instances recursively; simple Euclidean addition will be performed for all other arrays.
- jaxlie.manifold._deltas.rminus(a: GroupType, b: GroupType) -> jax.Array
- jaxlie.manifold._deltas.rminus(a: PytreeType, b: PytreeType) -> jaxlie.manifold._tree_utils.TangentPytree
Manifold right minus. Computes delta = T_ab.log() = (T_wa.inverse() @ T_wb).log(), where T_wa and T_wb are the arguments a and b. Supports pytrees containing Lie group instances recursively; simple Euclidean subtraction will be performed for all other arrays.
- jaxlie.manifold._deltas.rplus_jacobian_parameters_wrt_delta(transform)
Analytical Jacobians for jaxlie.manifold.rplus(), linearized around a zero local delta. Mostly useful for reducing JIT compile times for tangent-space optimization.
Equivalent to:
    def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
        # Since transform objects are pytree containers, note that `jacfwd` returns a
        # transformation object itself and that the Jacobian terms corresponding to the
        # parameters are grabbed explicitly.
        return jax.jacfwd(
            jaxlie.manifold.rplus,  # Args are (transform, delta)
            argnums=1,  # Jacobian wrt delta
        )(transform, onp.zeros(transform.tangent_dim)).parameters()
- Parameters:
transform (jaxlie.MatrixLieGroup) – Transform to linearize around.
- Returns:
Jacobian. Shape should be (Group.parameters_dim, Group.tangent_dim).
- Return type:
jax.Array