Source code for jaxlie.manifold._deltas

"""Helpers for recursively applying tangent-space deltas."""

from typing import Any, TypeVar, Union, overload

import jax
import numpy as onp
from jax import numpy as jnp

from .. import hints
from .._base import MatrixLieGroup
from .._se2 import SE2
from .._se3 import SE3
from .._so2 import SO2
from .._so3 import SO3
from . import _tree_utils

PytreeType = TypeVar("PytreeType")
GroupType = TypeVar("GroupType", bound=MatrixLieGroup)


def _rplus(transform: GroupType, delta: jax.Array) -> GroupType:
    assert isinstance(transform, MatrixLieGroup)
    assert isinstance(delta, (jax.Array, onp.ndarray))
    return transform @ type(transform).exp(delta)


@overload
def rplus(
    transform: GroupType,
    delta: hints.Array,
) -> GroupType: ...


@overload
def rplus(
    transform: PytreeType,
    delta: _tree_utils.TangentPytree,
) -> PytreeType: ...


# Using our typevars in the overloaded signature will cause errors.
[docs] def rplus( transform: Union[MatrixLieGroup, Any], delta: Union[hints.Array, Any], ) -> Union[MatrixLieGroup, Any]: """Manifold right plus. Computes `T' = T @ exp(delta)`. Supports pytrees containing Lie group instances recursively; simple Euclidean addition will be performed for all other arrays. """ return _tree_utils._map_group_trees(_rplus, jnp.add, transform, delta)
def _rminus(a: GroupType, b: GroupType) -> jax.Array: assert isinstance(a, MatrixLieGroup) and isinstance(b, MatrixLieGroup) return (a.inverse() @ b).log() @overload def rminus(a: GroupType, b: GroupType) -> jax.Array: ... @overload def rminus(a: PytreeType, b: PytreeType) -> _tree_utils.TangentPytree: ... # Using our typevars in the overloaded signature will cause errors.
[docs] def rminus( a: Union[MatrixLieGroup, Any], b: Union[MatrixLieGroup, Any] ) -> Union[jax.Array, _tree_utils.TangentPytree]: """Manifold right minus. Computes `delta = T_ab.log() = (T_wa.inverse() @ T_wb).log()`. Supports pytrees containing Lie group instances recursively; simple Euclidean subtraction will be performed for all other arrays. """ return _tree_utils._map_group_trees(_rminus, jnp.subtract, a, b)
[docs] @jax.jit def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array: """Analytical Jacobians for `jaxlie.manifold.rplus()`, linearized around a zero local delta. Mostly useful for reducing JIT compile times for tangent-space optimization. Equivalent to -- ``` def rplus_jacobian_parameters_wrt_delta(transform: MatrixLieGroup) -> jax.Array: # Since transform objects are pytree containers, note that `jacfwd` returns a # transformation object itself and that the Jacobian terms corresponding to the # parameters are grabbed explicitly. return jax.jacfwd( jaxlie.manifold.rplus, # Args are (transform, delta) argnums=1, # Jacobian wrt delta )(transform, onp.zeros(transform.tangent_dim)).parameters() ``` Args: transform: Transform to linearize around. Returns: Jacobian. Shape should be `(Group.parameters_dim, Group.tangent_dim)`. """ if isinstance(transform, SO2): # Jacobian row indices: cos, sin # Jacobian col indices: theta J = jnp.zeros((*transform.get_batch_axes(), 2, 1)) cos, sin = jnp.moveaxis(transform.unit_complex, -1, 0) J = J.at[..., 0, 0].set(-sin).at[..., 1, 0].set(cos) elif isinstance(transform, SE2): # Jacobian row indices: cos, sin, x, y # Jacobian col indices: vx, vy, omega J = jnp.zeros((*transform.get_batch_axes(), 4, 3)) # Translation terms. J = J.at[..., 2:, :2].set(transform.rotation().as_matrix()) # Rotation terms. J = J.at[..., :2, 2:3].set( rplus_jacobian_parameters_wrt_delta(transform.rotation()) ) elif isinstance(transform, SO3): # Jacobian row indices: qw, qx, qy, qz # Jacobian col indices: omega x, omega y, omega z w, x, y, z = jnp.moveaxis(transform.wxyz, -1, 0) neg_x = -x neg_y = -y neg_z = -z J = ( jnp.stack( [ neg_x, neg_y, neg_z, w, neg_z, y, z, w, neg_x, neg_y, x, w, ], axis=-1, ).reshape((*transform.get_batch_axes(), 4, 3)) / 2.0 ) elif isinstance(transform, SE3): # Jacobian row indices: qw, qx, qy, qz, x, y, z # Jacobian col indices: vx, vy, vz, omega x, omega y, omega z J = jnp.zeros((*transform.get_batch_axes(), 7, 6)) # Translation terms. J = J.at[..., 4:, :3].set(transform.rotation().as_matrix()) # Rotation terms. J = J.at[..., :4, 3:6].set( rplus_jacobian_parameters_wrt_delta(transform.rotation()) ) else: assert False, f"Unsupported type: {type(transform)}" assert J.shape == ( *transform.get_batch_axes(), transform.parameters_dim, transform.tangent_dim, ) return J