jaxlie.manifold

Package Contents

Functions

grad(…)

Same as jax.grad, but computes gradients of Lie groups with respect to tangent spaces.

value_and_grad(…)

Same as jax.value_and_grad, but computes gradients of Lie groups with respect to tangent spaces.

zero_tangents(pytree)

Replace all values in a Pytree with zero vectors on the corresponding tangent spaces.

rminus(…)

Manifold right minus. Computes delta = T_ab.log() = (T_wa.inverse() @ T_wb).log().

rplus(…)

Manifold right plus. Computes T' = T @ exp(delta).

rplus_jacobian_parameters_wrt_delta(transform)

Analytical Jacobians for jaxlie.manifold.rplus(), linearized around a zero local delta.

normalize_all(pytree)

Call .normalize() on each Lie group instance in a pytree.

jaxlie.manifold.grad(fun: Callable[P, Any], argnums: int = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, jaxlie.manifold._tree_utils.TangentPytree]
jaxlie.manifold.grad(fun: Callable[P, Any], argnums: Sequence[int], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[jaxlie.manifold._tree_utils.TangentPytree, ...]]

Same as jax.grad, but computes gradients of Lie groups with respect to tangent spaces.

jaxlie.manifold.value_and_grad(fun: Callable[P, Any], argnums: int = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[Any, jaxlie.manifold._tree_utils.TangentPytree]]
jaxlie.manifold.value_and_grad(fun: Callable[P, Any], argnums: Sequence[int], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[Any, Tuple[jaxlie.manifold._tree_utils.TangentPytree, ...]]]

Same as jax.value_and_grad, but computes gradients of Lie groups with respect to tangent spaces.

jaxlie.manifold.zero_tangents(pytree)

Replace all values in a Pytree with zero vectors on the corresponding tangent spaces.

Parameters:

pytree (Any)

Return type:

jaxlie.manifold._tree_utils.TangentPytree

jaxlie.manifold.rminus(a: GroupType, b: GroupType) -> jax.Array
jaxlie.manifold.rminus(a: PytreeType, b: PytreeType) -> jaxlie.manifold._tree_utils.TangentPytree

Manifold right minus. Computes delta = T_ab.log() = (T_wa.inverse() @ T_wb).log(), where a = T_wa and b = T_wb.

Supports pytrees containing Lie group instances recursively; simple Euclidean subtraction will be performed for all other arrays.

jaxlie.manifold.rplus(transform: GroupType, delta: jaxlie.hints.Array) -> GroupType
jaxlie.manifold.rplus(transform: PytreeType, delta: jaxlie.manifold._tree_utils.TangentPytree) -> PytreeType

Manifold right plus. Computes T' = T @ exp(delta).

Supports pytrees containing Lie group instances recursively; simple Euclidean addition will be performed for all other arrays.

jaxlie.manifold.rplus_jacobian_parameters_wrt_delta(transform)

Analytical Jacobians for jaxlie.manifold.rplus(), linearized around a zero local delta.

Mostly useful for reducing JIT compile times for tangent-space optimization.

Equivalent to:

import jax
import jaxlie
import numpy as onp
from jaxlie import MatrixLieGroup

def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
    # Since transform objects are pytree containers, note that `jacfwd` returns a
    # transformation object itself and that the Jacobian terms corresponding to the
    # parameters are grabbed explicitly.
    return jax.jacfwd(
        jaxlie.manifold.rplus,  # Args are (transform, delta)
        argnums=1,  # Jacobian wrt delta
    )(transform, onp.zeros(transform.tangent_dim)).parameters()

Parameters:

transform (jaxlie.MatrixLieGroup) – Transform to linearize around.

Returns:

Jacobian. Shape should be (Group.parameters_dim, Group.tangent_dim).

Return type:

jax.Array

jaxlie.manifold.normalize_all(pytree)

Call .normalize() on each Lie group instance in a pytree.

Results in a naive projection of each group instance to its respective manifold.

Parameters:

pytree (PytreeType)

Return type:

PytreeType