Source code for jaxlie.manifold._backprop

from __future__ import annotations

from typing import Any, Callable, Sequence, Tuple, Union, overload

import jax
from jax import numpy as jnp
from typing_extensions import ParamSpec

from .._base import MatrixLieGroup
from . import _deltas, _tree_utils


def zero_tangents(pytree: Any) -> _tree_utils.TangentPytree:
    """Replace all values in a Pytree with zero vectors on the corresponding
    tangent spaces."""

    def tangent_zero(t: MatrixLieGroup) -> jax.Array:
        return jnp.zeros(t.get_batch_axes() + (t.tangent_dim,))

    return _tree_utils._map_group_trees(
        tangent_zero,
        lambda array: jnp.zeros_like(array),
        pytree,
    )
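
# Usage sketch (illustrative; not part of the original module). Applied to a
# pytree that mixes Lie group instances and ordinary arrays, `zero_tangents`
# zeroes each leaf on its tangent space. Assumes `jaxlie.SO3` is importable,
# with tangent dimension 3:
#
#     import jaxlie
#
#     tangents = zero_tangents({"rot": jaxlie.SO3.identity(), "x": jnp.ones(2)})
#     # tangents["rot"].shape == (3,): zeros in the so(3) tangent space.
#     # tangents["x"].shape == (2,): ordinary arrays are zeroed as-is.
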
AxisName = Any

P = ParamSpec("P")


@overload
def grad(
    fun: Callable[P, Any],
    argnums: int = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, _tree_utils.TangentPytree]: ...


@overload
def grad(
    fun: Callable[P, Any],
    argnums: Sequence[int],
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[_tree_utils.TangentPytree, ...]]: ...


def grad(
    fun: Callable[P, Any],
    argnums: Union[int, Sequence[int]] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
):
    """Same as `jax.grad`, but computes gradients of Lie groups with respect to
    tangent spaces."""
    compute_value_and_grad = value_and_grad(
        fun=fun,
        argnums=argnums,
        has_aux=has_aux,
        holomorphic=holomorphic,
        allow_int=allow_int,
        reduce_axes=reduce_axes,
    )

    def grad_fun(*args, **kwargs):
        ret = compute_value_and_grad(*args, **kwargs)
        if has_aux:
            # `ret` is ((value, aux), grad); reorder to (grad, aux) to match
            # the `jax.grad` convention.
            return ret[1], ret[0][1]
        else:
            return ret[1]

    return grad_fun
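
# Usage sketch (illustrative; not part of the original module). Gradients with
# respect to Lie group arguments come back on the tangent space. Assumes
# `jaxlie.SE3`, whose tangent vectors live in R^6:
#
#     import jaxlie
#
#     def loss(T: jaxlie.SE3) -> jax.Array:
#         return jnp.sum(T.translation() ** 2)
#
#     tangent_grad = grad(loss)(jaxlie.SE3.identity())
#     # tangent_grad.shape == (6,)
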
@overload
def value_and_grad(
    fun: Callable[P, Any],
    argnums: int = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[Any, _tree_utils.TangentPytree]]: ...


@overload
def value_and_grad(
    fun: Callable[P, Any],
    argnums: Sequence[int],
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
) -> Callable[P, Tuple[Any, Tuple[_tree_utils.TangentPytree, ...]]]: ...


def value_and_grad(
    fun: Callable[P, Any],
    argnums: Union[int, Sequence[int]] = 0,
    has_aux: bool = False,
    holomorphic: bool = False,
    allow_int: bool = False,
    reduce_axes: Sequence[AxisName] = (),
):
    """Same as `jax.value_and_grad`, but computes gradients of Lie groups with
    respect to tangent spaces."""

    def wrapped_grad(*args, **kwargs):
        def tangent_fun(*tangent_args, **tangent_kwargs):
            # Evaluate `fun` with each input perturbed on its tangent space.
            return fun(  # type: ignore
                *_deltas.rplus(args, tangent_args),
                **_deltas.rplus(kwargs, tangent_kwargs),
            )

        # Put arguments onto tangent space; differentiating `tangent_fun` at
        # zero then yields gradients with respect to local tangent coordinates.
        tangent_args = map(zero_tangents, args)
        tangent_kwargs = {k: zero_tangents(v) for k, v in kwargs.items()}

        return jax.value_and_grad(
            fun=tangent_fun,
            argnums=argnums,
            has_aux=has_aux,
            holomorphic=holomorphic,
            allow_int=allow_int,
            reduce_axes=reduce_axes,
        )(*tangent_args, **tangent_kwargs)

    return wrapped_grad  # type: ignore
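
# Usage sketch (illustrative; not part of the original module). The tangent
# gradient pairs naturally with `_deltas.rplus` (exposed as
# `jaxlie.manifold.rplus`) for manifold gradient steps; the step size 1e-2 and
# the `loss` function from the sketch above are arbitrary choices:
#
#     import jaxlie
#
#     T = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(0))
#     value, tangent_grad = value_and_grad(loss)(T)
#     T_new = _deltas.rplus(T, -1e-2 * tangent_grad)  # one descent step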