jaxlie.manifold._backprop

Module Contents

Functions

zero_tangents(pytree)

Replace all values in a Pytree with zero vectors on the corresponding tangent spaces.

grad(…)

Same as jax.grad, but computes gradients of Lie groups with respect to tangent spaces.

value_and_grad(…)

Same as jax.value_and_grad, but computes gradients of Lie groups with respect to tangent spaces.

Attributes

AxisName

P

jaxlie.manifold._backprop.zero_tangents(pytree)

Replace all values in a Pytree with zero vectors on the corresponding tangent spaces.

Parameters:

pytree (Any) – Pytree whose values are replaced with zero vectors on their tangent spaces.

Return type:

jaxlie.manifold._tree_utils.TangentPytree
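As a rough illustration of what "zero vectors on the corresponding tangent spaces" means, the sketch below uses plain JAX rather than jaxlie itself. UnitComplexSO2 and zero_tangents_sketch are hypothetical names introduced here. The sketch assumes that a Lie group node maps to a zero array of length tangent_dim (not the length of its stored parameters) and that every other leaf maps to zeros of its own shape; batch axes are ignored.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


# Hypothetical stand-in for a Lie group (not a jaxlie class): an SO(2)
# rotation stored as a unit complex number (cos, sin).
@jax.tree_util.register_pytree_node_class
@dataclass
class UnitComplexSO2:
    unit_complex: jax.Array
    tangent_dim = 1  # Class attribute (no annotation), so not a dataclass field.

    def tree_flatten(self):
        return (self.unit_complex,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def zero_tangents_sketch(pytree):
    """Hypothetical re-implementation with assumed behavior: group nodes become
    zero tangent vectors, other leaves become zeros of the same shape."""

    def is_group(node):
        return isinstance(node, UnitComplexSO2)

    def to_zero(node):
        if is_group(node):
            return jnp.zeros((node.tangent_dim,))
        return jnp.zeros_like(node)

    # is_leaf stops tree_map from descending into the group's stored array.
    return jax.tree_util.tree_map(to_zero, pytree, is_leaf=is_group)


params = {"rotation": UnitComplexSO2(jnp.array([1.0, 0.0])), "bias": jnp.ones(3)}
zeros = zero_tangents_sketch(params)
# zeros["rotation"] has shape (1,); zeros["bias"] has shape (3,).
```

Treating the group as a leaf is what keeps the tangent dimension (1) distinct from the storage dimension (2) of the rotation.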

jaxlie.manifold._backprop.AxisName
jaxlie.manifold._backprop.P
jaxlie.manifold._backprop.grad(fun: Callable[P, Any], argnums: int = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, jaxlie.manifold._tree_utils.TangentPytree]
jaxlie.manifold._backprop.grad(fun: Callable[P, Any], argnums: Sequence[int], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[jaxlie.manifold._tree_utils.TangentPytree, ...]]

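Same as jax.grad, but computes gradients of Lie groups with respect to their tangent spaces. With an int argnums, the gradient is a single TangentPytree; with a sequence of ints, it is a tuple containing one TangentPytree per selected argument.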
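The difference between a tangent-space gradient and an ordinary jax.grad can be sketched in plain JAX. The listing assumes that the gradient with respect to a Lie group argument x is the derivative of fun(x ⊕ δ) with respect to δ at δ = 0, where ⊕ applies a tangent perturbation on the right. This is one reading of "with respect to tangent spaces", not a copy of jaxlie's code. so2_rplus, tangent_grad, and angle_error are hypothetical helpers for a single SO(2) argument stored as (cos, sin).

```python
import jax
import jax.numpy as jnp


def so2_rplus(unit_complex, delta):
    """Hypothetical SO(2) retraction: rotate (cos, sin) by the tangent angle delta[0]."""
    c, s = unit_complex
    cd, sd = jnp.cos(delta[0]), jnp.sin(delta[0])
    # Complex product (c + i s)(cd + i sd), i.e. x composed with exp(delta).
    return jnp.array([c * cd - s * sd, s * cd + c * sd])


def tangent_grad(fun):
    """Hypothetical sketch of a tangent-space gradient for one SO(2) argument."""

    def grad_fun(unit_complex):
        # Perturb on the manifold, then differentiate at the zero tangent vector.
        perturbed = lambda delta: fun(so2_rplus(unit_complex, delta))
        return jax.grad(perturbed)(jnp.zeros(1))

    return grad_fun


def angle_error(unit_complex, target=0.2):
    theta = jnp.arctan2(unit_complex[1], unit_complex[0])
    return (theta - target) ** 2


R = jnp.array([jnp.cos(0.5), jnp.sin(0.5)])
g_tangent = tangent_grad(angle_error)(R)  # shape (1,), approximately [0.6]
g_ambient = jax.grad(angle_error)(R)  # shape (2,), in (cos, sin) storage coordinates
```

The tangent gradient has one entry, matching the single rotational degree of freedom, while plain jax.grad returns a two-entry vector in the redundant storage coordinates.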

jaxlie.manifold._backprop.value_and_grad(fun: Callable[P, Any], argnums: int = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[Any, jaxlie.manifold._tree_utils.TangentPytree]]
jaxlie.manifold._backprop.value_and_grad(fun: Callable[P, Any], argnums: Sequence[int], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()) -> Callable[P, Tuple[Any, Tuple[jaxlie.manifold._tree_utils.TangentPytree, ...]]]

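Same as jax.value_and_grad, but computes gradients of Lie groups with respect to their tangent spaces. The returned function produces a (value, gradient) pair; the gradient is a single TangentPytree for an int argnums, or a tuple of TangentPytrees for a sequence of ints.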
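A value_and_grad variant of the same idea returns the function value together with the tangent gradient, which is what an optimizer step typically needs. This plain-JAX sketch assumes right-plus perturbations and uses hypothetical helpers (so2_rplus, tangent_value_and_grad, angle_error) rather than jaxlie's own code; the last line takes one gradient-descent step through the retraction so the iterate stays a unit complex number.

```python
import jax
import jax.numpy as jnp


def so2_rplus(unit_complex, delta):
    """Hypothetical SO(2) retraction: rotate (cos, sin) by the tangent angle delta[0]."""
    c, s = unit_complex
    cd, sd = jnp.cos(delta[0]), jnp.sin(delta[0])
    return jnp.array([c * cd - s * sd, s * cd + c * sd])


def tangent_value_and_grad(fun):
    """Hypothetical sketch: return fun(x) and its gradient w.r.t. a tangent perturbation."""

    def wrapped(unit_complex):
        # At delta = 0 the retraction is the identity, so the value equals fun(x).
        perturbed = lambda delta: fun(so2_rplus(unit_complex, delta))
        return jax.value_and_grad(perturbed)(jnp.zeros(1))

    return wrapped


def angle_error(unit_complex, target=0.2):
    theta = jnp.arctan2(unit_complex[1], unit_complex[0])
    return (theta - target) ** 2


R = jnp.array([jnp.cos(0.5), jnp.sin(0.5)])
value, g = tangent_value_and_grad(angle_error)(R)  # approximately 0.09 and [0.6]
# One gradient-descent step (step size 0.5) taken on the manifold via the retraction.
R_next = so2_rplus(R, -0.5 * g)
```

Stepping through the retraction instead of subtracting from the stored (cos, sin) pair keeps R_next on the unit circle without any re-normalization.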