jaxlie.manifold

Package Contents

Functions

rminus(a: T, b: T) → types.TangentVector

Manifold right minus.

rplus(transform: T, delta: types.TangentVector) → T

Manifold right plus.

rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) → jnp.ndarray

Analytical Jacobians for jaxlie.manifold.rplus(), linearized around a zero local delta.

jaxlie.manifold.rminus(a: T, b: T) → types.TangentVector

Manifold right minus.

Computes delta = (T_wa.inverse() @ T_wb).log().

Parameters
  • a (T) – T_wa

  • b (T) – T_wb

Returns

types.TangentVector – T_ab.log()

jaxlie.manifold.rplus(transform: T, delta: types.TangentVector) → T

Manifold right plus.

Computes T_wb = T_wa @ exp(delta).

Parameters
  • transform (T) – T_wa

  • delta (types.TangentVector) – T_ab.log()

Returns

T – T_wb

jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) → jnp.ndarray

Analytical Jacobians for jaxlie.manifold.rplus(), linearized around a zero local delta.

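Useful for on-manifold optimization: multiplying a Jacobian taken with respect to the raw group parameters by this matrix yields a Jacobian with respect to a local tangent-space delta.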

Equivalent to:

import jax
import jax.numpy as jnp
import jaxlie
import numpy as onp
from jaxlie import MatrixLieGroup

def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jnp.ndarray:
    # Since transform objects are PyTree containers, note that `jacfwd` returns a
    # transformation object itself and that the Jacobian terms corresponding to the
    # parameters are grabbed explicitly.
    return jax.jacfwd(
        jaxlie.manifold.rplus,  # Args are (transform, delta)
        argnums=1,  # Jacobian wrt delta
    )(transform, onp.zeros(transform.tangent_dim)).parameters()

Parameters
  • transform (MatrixLieGroup) – Transform T_wa at which rplus() is linearized.

Returns

jnp.ndarray – Jacobian of the parameters of rplus(transform, delta) with respect to delta, evaluated at delta = 0. Its shape is (Group.parameters_dim, Group.tangent_dim).