:py:mod:`jaxlie.manifold`
=========================

.. py:module:: jaxlie.manifold


Package Contents
----------------

Functions
~~~~~~~~~

.. autoapisummary::

   jaxlie.manifold.grad
   jaxlie.manifold.value_and_grad
   jaxlie.manifold.zero_tangents
   jaxlie.manifold.rminus
   jaxlie.manifold.rplus
   jaxlie.manifold.rplus_jacobian_parameters_wrt_delta
   jaxlie.manifold.normalize_all


.. py:function:: grad(fun: Callable[P, Any], argnums: int = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, jaxlie.manifold._tree_utils.TangentPytree]
                 grad(fun: Callable[P, Any], argnums: Sequence[int], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[jaxlie.manifold._tree_utils.TangentPytree, Ellipsis]]

   Same as ``jax.grad``, but computes gradients of Lie groups with respect to their tangent spaces.

.. py:function:: value_and_grad(fun: Callable[P, Any], argnums: int = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[Any, jaxlie.manifold._tree_utils.TangentPytree]]
                 value_and_grad(fun: Callable[P, Any], argnums: Sequence[int], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[Any, Tuple[jaxlie.manifold._tree_utils.TangentPytree, Ellipsis]]]

   Same as ``jax.value_and_grad``, but computes gradients of Lie groups with respect to their tangent spaces.

.. py:function:: zero_tangents(pytree)

   Replace all values in a pytree with zero vectors on the corresponding tangent spaces.

.. py:function:: rminus(a: GroupType, b: GroupType) -> jax.Array
                 rminus(a: PytreeType, b: PytreeType) -> jaxlie.manifold._tree_utils.TangentPytree

   Manifold right minus. Computes ``delta = T_ab.log() = (T_wa.inverse() @ T_wb).log()``.
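
   A minimal sketch of the definition above, assuming a ``jaxlie`` version that exposes ``jaxlie.manifold.rminus`` and ``jaxlie.manifold.rplus`` as listed on this page. It checks ``rminus`` against the explicit ``(a.inverse() @ b).log()`` formula for two ``SO3`` rotations, then shows that ``rplus`` undoes ``rminus`` on a dictionary pytree mixing a group with a plain array. The variable names and rotation angles are illustrative only.

   .. code-block:: python

      import jax.numpy as jnp
      import jaxlie

      # Two rotations about the z axis (illustrative angles, in radians).
      a = jaxlie.SO3.from_z_radians(0.3)
      b = jaxlie.SO3.from_z_radians(1.0)

      # Right minus on single group instances.
      delta = jaxlie.manifold.rminus(a, b)
      # The same quantity written out explicitly: (a.inverse() @ b).log().
      expected = (a.inverse() @ b).log()

      # A pytree mixing a Lie group with a plain array. The group entry uses
      # the Lie group right minus; the array entry uses Euclidean subtraction.
      tree_a = {"rot": a, "offset": jnp.array([1.0, 2.0])}
      tree_b = {"rot": b, "offset": jnp.array([4.0, 6.0])}
      tree_delta = jaxlie.manifold.rminus(tree_a, tree_b)

      # Right plus inverts right minus, so this should recover tree_b.
      recovered = jaxlie.manifold.rplus(tree_a, tree_delta)

   For two rotations about the same axis, the right minus reduces to the difference of the angles about that axis, so ``delta`` here should be close to ``[0, 0, 0.7]``.
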

   Supports pytrees containing Lie group instances recursively; simple Euclidean subtraction is performed for all other arrays.

.. py:function:: rplus(transform: GroupType, delta: jaxlie.hints.Array) -> GroupType
                 rplus(transform: PytreeType, delta: jaxlie.manifold._tree_utils.TangentPytree) -> PytreeType

   Manifold right plus. Computes ``T' = T @ exp(delta)``.

   Supports pytrees containing Lie group instances recursively; simple Euclidean addition is performed for all other arrays.

.. py:function:: rplus_jacobian_parameters_wrt_delta(transform)

   Analytical Jacobians for ``jaxlie.manifold.rplus()``, linearized around a zero local delta.

   Mostly useful for reducing JIT compile times for tangent-space optimization.

   Equivalent to:

   .. code-block:: python

      import jax
      import jaxlie
      import numpy as onp

      def rplus_jacobian_parameters_wrt_delta(transform: jaxlie.MatrixLieGroup) -> jax.Array:
          # Since transform objects are pytree containers, note that `jacfwd` returns a
          # transformation object itself and that the Jacobian terms corresponding to the
          # parameters are grabbed explicitly.
          return jax.jacfwd(
              jaxlie.manifold.rplus,  # Args are (transform, delta)
              argnums=1,  # Jacobian wrt delta
          )(transform, onp.zeros(transform.tangent_dim)).parameters()

   :param transform: Transform to linearize around.
   :returns: Jacobian. Shape should be ``(Group.parameters_dim, Group.tangent_dim)``.

.. py:function:: normalize_all(pytree)

   Call ``.normalize()`` on each Lie group instance in a pytree.

   This amounts to a naive projection of each group instance onto its respective manifold.
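
The functions on this page compose into a tangent-space gradient-descent loop. The sketch below is illustrative and not part of ``jaxlie``: the ``cost`` function, the ``target`` rotation, the learning rate, and the iteration count are hypothetical choices. It applies ``value_and_grad`` to a cost over a dictionary pytree mixing an ``SO3`` with a plain array, takes steps with ``rplus``, and re-projects with ``normalize_all``.

.. code-block:: python

   import jax
   import jax.numpy as jnp
   import jaxlie

   # Hypothetical target: the rotation we want the optimizer to recover.
   target = jaxlie.SO3.from_rpy_radians(0.2, -0.4, 1.1)


   def cost(params):
       # Squared tangent-space error for the rotation, plus an ordinary
       # least-squares term pulling the plain array toward 1.0.
       rot_error = (params["rot"].inverse() @ target).log()
       return jnp.sum(rot_error**2) + jnp.sum((params["bias"] - 1.0) ** 2)


   # Gradients come back as a tangent pytree: a (3,) vector for the SO3
   # entry and an ordinary gradient for the plain array.
   cost_and_grad = jax.jit(jaxlie.manifold.value_and_grad(cost))

   params = {"rot": jaxlie.SO3.identity(), "bias": jnp.zeros(2)}
   learning_rate = 0.1
   history = []
   for _ in range(100):
       value, grads = cost_and_grad(params)
       history.append(float(value))
       # Scale the tangent gradient into a descent step.
       step = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
       # Retract: T' = T @ exp(step) for groups, x + step for plain arrays.
       params = jaxlie.manifold.rplus(params, step)
       # Re-project group parameters to counter floating-point drift.
       params = jaxlie.manifold.normalize_all(params)

Because gradients are taken with respect to a zero tangent-space delta, the ``SO3`` entry receives a 3-vector rather than a gradient on its four quaternion parameters. Each ``rplus`` step therefore stays on the manifold up to floating-point drift, which ``normalize_all`` removes.
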