# Source code for jaxlie.manifold._deltas

"""Helpers for recursively applying tangent-space deltas."""

from typing import Any, Callable, TypeVar, Union, cast, overload

import jax
import numpy as onp
from jax import numpy as jnp

from .. import hints
from .._base import MatrixLieGroup
from .._se2 import SE2
from .._se3 import SE3
from .._so2 import SO2
from .._so3 import SO3
from . import _tree_utils

PytreeType = TypeVar("PytreeType")
GroupType = TypeVar("GroupType", bound=MatrixLieGroup)
CallableType = TypeVar("CallableType", bound=Callable)


def _naive_auto_vmap(f: CallableType) -> CallableType:
    """Decorate `f` so it is `jax.vmap`-ed once per batch axis of its group inputs.

    Every positional and keyword argument that is a `MatrixLieGroup` is inspected;
    all of them must report the same `get_batch_axes()` value, and at least one
    such argument must be present. The wrapped function is then nested inside
    `jax.vmap` once for each entry of that value before being called with the
    original arguments.

    Args:
        f: Function to wrap.

    Returns:
        The wrapping function, which forwards `*args` and `**kwargs` to the
        vmapped `f`.
    """

    def inner(*args, **kwargs):
        reference_axes = None
        for candidate in (*args, *kwargs.values()):
            # Only Lie group arguments determine the batch axes.
            if not isinstance(candidate, MatrixLieGroup):
                continue
            candidate_axes = candidate.get_batch_axes()
            if reference_axes is None:
                # First group argument seen: it defines the expected batch axes.
                reference_axes = candidate_axes
                continue
            # Every later group argument has to agree with the first one.
            assert candidate_axes == reference_axes
        assert reference_axes is not None

        # Peel off one leading axis per vmap layer.
        batched: Callable = f
        for _ in range(len(reference_axes)):
            batched = jax.vmap(batched)
        return batched(*args, **kwargs)

    return inner  # type: ignore


@_naive_auto_vmap
def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
    """Right-plus for a single Lie group instance: `transform @ exp(delta)`.

    Args:
        transform: Lie group instance to perturb.
        delta: Tangent-space array passed to the `exp` of `transform`'s own class.

    Returns:
        The product of `transform` with the exponentiated `delta`.
    """
    assert isinstance(transform, MatrixLieGroup)
    assert isinstance(delta, (jax.Array, onp.ndarray))
    # Use the concrete class of the input so the result stays in the same group.
    group_cls = type(transform)
    perturbation = group_cls.exp(delta)
    return transform @ perturbation


@overload
def rplus(
    transform: GroupType,
    delta: hints.Array,
) -> GroupType: ...


@overload
def rplus(
    transform: PytreeType,
    delta: _tree_utils.TangentPytree,
) -> PytreeType: ...


# Using our typevars in the overloaded signature will cause errors.
def rplus(
    transform: Union[MatrixLieGroup, Any],
    delta: Union[hints.Array, Any],
) -> Union[MatrixLieGroup, Any]:
    """Manifold right plus.

    Computes `T' = T @ exp(delta)`. Supports pytrees containing Lie group instances
    recursively; simple Euclidean addition will be performed for all other arrays.

    Args:
        transform: A Lie group instance, or a pytree that may contain Lie group
            instances alongside other arrays.
        delta: Tangent-space delta for `transform`: a single array when
            `transform` is a Lie group instance, otherwise the matching tangent
            pytree.

    Returns:
        The updated transform or pytree.
    """
    # Group leaves are handled by `_rplus`, everything else by `jnp.add`.
    return _tree_utils._map_group_trees(_rplus, jnp.add, transform, delta)
@_naive_auto_vmap
def _rminus(a: GroupType, b: GroupType) -> jax.Array:
    """Right-minus for a single pair of Lie group instances.

    Args:
        a: First Lie group instance.
        b: Second Lie group instance.

    Returns:
        `(a.inverse() @ b).log()`.
    """
    assert isinstance(a, MatrixLieGroup)
    assert isinstance(b, MatrixLieGroup)
    a_to_b = a.inverse() @ b
    return a_to_b.log()


@overload
def rminus(a: GroupType, b: GroupType) -> jax.Array: ...


@overload
def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ...


# Using our typevars in the overloaded signature will cause errors.
def rminus(
    a: Union[MatrixLieGroup, Any], b: Union[MatrixLieGroup, Any]
) -> Union[jax.Array, _tree_utils.TangentPytree]:
    """Manifold right minus.

    Computes `delta = T_ab.log() = (T_wa.inverse() @ T_wb).log()`. Supports pytrees
    containing Lie group instances recursively; simple Euclidean subtraction will be
    performed for all other arrays.

    Args:
        a: A Lie group instance (`T_wa`), or a pytree that may contain Lie group
            instances alongside other arrays.
        b: The counterpart of `a` (`T_wb`), with the same structure.

    Returns:
        The tangent-space delta: a single array for a pair of Lie group
        instances, otherwise a tangent pytree.
    """
    # Group leaves are handled by `_rminus`, everything else by `jnp.subtract`.
    return _tree_utils._map_group_trees(_rminus, jnp.subtract, a, b)
@jax.jit
def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
    """Analytical Jacobians for `jaxlie.manifold.rplus()`, linearized around a zero
    local delta.

    Mostly useful for reducing JIT compile times for tangent-space optimization.

    Equivalent to --
    ```
    def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array:
        # Since transform objects are pytree containers, note that `jacfwd` returns a
        # transformation object itself and that the Jacobian terms corresponding to the
        # parameters are grabbed explicitly.
        return jax.jacfwd(
            jaxlie.manifold.rplus,  # Args are (transform, delta)
            argnums=1,  # Jacobian wrt delta
        )(transform, onp.zeros(transform.tangent_dim)).parameters()
    ```

    Only the exact types `SO2`, `SE2`, `SO3` and `SE3` are handled; any other
    type (including subclasses of these) fails the final `else` assertion.

    Args:
        transform: Transform to linearize around.

    Returns:
        Jacobian. Shape should be `(Group.parameters_dim, Group.tangent_dim)`.
    """
    # Branches are selected on the exact class, not via `isinstance`.
    if type(transform) is SO2:
        # Jacobian row indices: cos, sin
        # Jacobian col indices: theta

        transform_so2 = cast(SO2, transform)
        J = jnp.zeros((2, 1))

        cos, sin = transform_so2.unit_complex
        J = J.at[0].set(-sin).at[1].set(cos)

    elif type(transform) is SE2:
        # Jacobian row indices: cos, sin, x, y
        # Jacobian col indices: vx, vy, omega

        transform_se2 = cast(SE2, transform)
        J = jnp.zeros((4, 3))

        # Translation terms.
        J = J.at[2:, :2].set(transform_se2.rotation().as_matrix())

        # Rotation terms: reuse the SO2 Jacobian of the rotation component.
        J = J.at[:2, 2:3].set(
            rplus_jacobian_parameters_wrt_delta(transform_se2.rotation())
        )

    elif type(transform) is SO3:
        # Jacobian row indices: qw, qx, qy, qz
        # Jacobian col indices: omega x, omega y, omega z

        transform_so3 = cast(SO3, transform)

        w, x, y, z = transform_so3.wxyz
        # The negated `w` is never used; it is only unpacked to keep the shapes aligned.
        _unused_neg_w, neg_x, neg_y, neg_z = -transform_so3.wxyz

        J = (
            jnp.array(
                [
                    [neg_x, neg_y, neg_z],
                    [w, neg_z, y],
                    [z, w, neg_x],
                    [neg_y, x, w],
                ]
            )
            / 2.0
        )

    elif type(transform) is SE3:
        # Jacobian row indices: qw, qx, qy, qz, x, y, z
        # Jacobian col indices: vx, vy, vz, omega x, omega y, omega z

        transform_se3 = cast(SE3, transform)
        J = jnp.zeros((7, 6))

        # Translation terms.
        J = J.at[4:, :3].set(transform_se3.rotation().as_matrix())

        # Rotation terms: reuse the SO3 Jacobian of the rotation component.
        J = J.at[:4, 3:6].set(
            rplus_jacobian_parameters_wrt_delta(transform_se3.rotation())
        )

    else:
        assert False, f"Unsupported type: {type(transform)}"

    # Sanity check against the dimensions advertised by the group itself.
    assert J.shape == (transform.parameters_dim, transform.tangent_dim)
    return J