from __future__ import annotations
from typing import cast
import jax
import jax_dataclasses as jdc
from jax import numpy as jnp
from typing_extensions import Annotated, override
from . import _base, hints
from ._so3 import SO3
from .utils import get_epsilon, register_lie_group
def _skew(omega: hints.Array) -> jax.Array:
    """Build the 3x3 skew-symmetric (cross-product) matrix of a 3-vector.

    Args:
        omega: Length-3 vector; it is unpacked into exactly three components.

    Returns:
        A 3x3 array ``S`` with zeros on the diagonal and ``S[i, j] == -S[j, i]``,
        laid out so that ``S @ v`` corresponds to the cross product of
        ``omega`` with ``v``.
    """
    x, y, z = omega
    zero = 0.0
    rows = [
        [zero, -z, y],
        [z, zero, -x],
        [-y, x, zero],
    ]
    return jnp.array(rows)
[docs]
@register_lie_group(
    matrix_dim=4,  # `as_matrix()` builds a 4x4 matrix.
    parameters_dim=7,  # 4 quaternion values + 3 translation values.
    tangent_dim=6,  # `exp()` asserts a tangent vector of shape (6,).
    space_dim=3,  # Translations are asserted to have shape (3,).
)
@jdc.pytree_dataclass
class SE3(jdc.EnforcedAnnotationsMixin, _base.SEBase[SO3]):
    """Special Euclidean group for proper rigid transforms in 3D.

    Internal parameterization is `(qw, qx, qy, qz, x, y, z)`: a wxyz quaternion
    for the rotation, followed by the xyz translation.

    Tangent parameterization is `(vx, vy, vz, omega_x, omega_y, omega_z)`: the
    translational part comes first, followed by the rotational part.
    """

    # SE3-specific.
    wxyz_xyz: Annotated[
        jax.Array,
        (..., 7),  # Shape.
        jnp.floating,  # Data-type.
    ]
    """Internal parameters. wxyz quaternion followed by xyz translation."""
[docs]
@override
def __repr__(self) -> str:
    """Return a readable summary with parameters rounded to 5 decimals.

    The quaternion (first four parameters) and the translation (remaining
    parameters) are rounded separately and shown under the labels ``wxyz``
    and ``xyz``.
    """
    params = self.wxyz_xyz
    wxyz = jnp.round(params[..., :4], 5)
    xyz = jnp.round(params[..., 4:], 5)
    name = self.__class__.__name__
    return f"{name}(wxyz={wxyz}, xyz={xyz})"
# SE-specific.
[docs]
@staticmethod
@override
def from_rotation_and_translation(
    rotation: SO3,
    translation: hints.Array,
) -> SE3:
    """Assemble a transform from a rotation and a translation vector.

    Args:
        rotation: Rotation whose ``wxyz`` quaternion becomes the first four
            parameters.
        translation: Translation of shape ``(3,)``; checked with an
            ``assert``.

    Returns:
        An ``SE3`` whose parameters are the quaternion followed by the
        translation.
    """
    assert translation.shape == (3,)
    packed = jnp.concatenate([rotation.wxyz, translation])
    return SE3(wxyz_xyz=packed)
[docs]
@override
def rotation(self) -> SO3:
    """Return the rotation part, built from the leading wxyz quaternion."""
    quaternion = self.wxyz_xyz[..., :4]
    return SO3(wxyz=quaternion)
[docs]
@override
def translation(self) -> jax.Array:
    """Return the translation part: the parameters after the quaternion."""
    params = self.wxyz_xyz
    return params[..., 4:]
# Factory.
[docs]
@staticmethod
@override
def identity() -> SE3:
    """Return the identity transform.

    The parameters are the quaternion ``(1, 0, 0, 0)`` followed by a zero
    translation.
    """
    identity_wxyz = [1.0, 0.0, 0.0, 0.0]
    zero_xyz = [0.0, 0.0, 0.0]
    return SE3(wxyz_xyz=jnp.array(identity_wxyz + zero_xyz))
[docs]
@staticmethod
@override
def from_matrix(matrix: hints.Array) -> SE3:
    """Build a transform from a 4x4 homogeneous matrix.

    Args:
        matrix: Array of shape ``(4, 4)``; checked with an ``assert``. Only
            the top three rows are read.

    Returns:
        The ``SE3`` whose rotation comes from the upper-left 3x3 block and
        whose translation comes from the first three entries of the last
        column.
    """
    assert matrix.shape == (4, 4)
    # Currently assumes bottom row is [0, 0, 0, 1]; it is not validated.
    rotation_block = matrix[:3, :3]
    translation_column = matrix[:3, 3]
    return SE3.from_rotation_and_translation(
        rotation=SO3.from_matrix(rotation_block),
        translation=translation_column,
    )
# Accessors.
[docs]
@override
def as_matrix(self) -> jax.Array:
    """Return the transform as a 4x4 homogeneous matrix.

    Starts from a 4x4 identity, writes the rotation matrix into the
    upper-left 3x3 block and the translation into the first three entries of
    the last column. The bottom row is left as in the identity.
    """
    homogeneous = jnp.eye(4)
    homogeneous = homogeneous.at[:3, :3].set(self.rotation().as_matrix())
    homogeneous = homogeneous.at[:3, 3].set(self.translation())
    return homogeneous
[docs]
@override
def parameters(self) -> jax.Array:
    """Return the underlying parameter array (wxyz quaternion, then xyz)."""
    params = self.wxyz_xyz
    return params
# Operations.
[docs]
@staticmethod
@override
def exp(tangent: hints.Array) -> SE3:
    """Compute the exponential map from a tangent vector to a transform.

    Args:
        tangent: Tangent vector of shape ``(6,)``, ordered as
            ``(x, y, z, omega_x, omega_y, omega_z)``; the shape is checked
            with an ``assert``.

    Returns:
        The ``SE3`` whose rotation is ``SO3.exp(omega)`` and whose
        translation is ``V @ (x, y, z)``, with

            V = I + (1 - cos(theta)) / theta^2 * W
                  + (theta - sin(theta)) / theta^3 * W^2

        where ``W`` is the skew-symmetric matrix of ``omega`` and
        ``theta = |omega|``.
    """
    # Reference:
    # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L761

    # (x, y, z, omega_x, omega_y, omega_z)
    assert tangent.shape == (6,)

    omega = tangent[3:]
    rotation = SO3.exp(omega)

    theta_squared = omega @ omega
    # Below the dtype-dependent threshold the closed form divides by a
    # near-zero angle, so a series expansion is used instead.
    use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

    # Shim to avoid NaNs in jnp.where branches, which cause failures for
    # reverse-mode AD.
    theta_squared_safe = cast(
        jax.Array,
        jnp.where(
            use_taylor,
            1.0,  # Any non-zero value should do here.
            theta_squared,
        ),
    )
    del theta_squared
    theta_safe = jnp.sqrt(theta_squared_safe)

    skew_omega = _skew(omega)
    skew_omega_squared = skew_omega @ skew_omega

    V = jnp.where(
        use_taylor,
        # Second-order Taylor expansion of V about theta = 0:
        # (1 - cos(theta)) / theta^2 -> 1/2 and
        # (theta - sin(theta)) / theta^3 -> 1/6.
        # The rotation matrix is not a valid stand-in here: its first-order
        # coefficient is 1 rather than 1/2, which would make exp() disagree
        # with the small-angle branch of log() and give a wrong derivative
        # with respect to omega at the identity.
        jnp.eye(3) + 0.5 * skew_omega + skew_omega_squared / 6.0,
        (
            jnp.eye(3)
            + (1.0 - jnp.cos(theta_safe)) / (theta_squared_safe) * skew_omega
            + (theta_safe - jnp.sin(theta_safe))
            / (theta_squared_safe * theta_safe)
            * skew_omega_squared
        ),
    )

    return SE3.from_rotation_and_translation(
        rotation=rotation,
        translation=V @ tangent[:3],
    )
[docs]
@override
def log(self) -> jax.Array:
    """Compute the logarithmic map of this transform.

    Returns:
        A length-6 tangent vector: ``V_inv @ translation`` followed by the
        rotation log ``omega``, matching the
        ``(vx, vy, vz, omega_x, omega_y, omega_z)`` tangent ordering.
    """
    # Reference:
    # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/se3.hpp#L223
    omega = self.rotation().log()
    skew_omega = _skew(omega)
    skew_omega_squared = skew_omega @ skew_omega

    theta_squared = omega @ omega
    use_taylor = theta_squared < get_epsilon(theta_squared.dtype)

    # Shim to avoid NaNs in jnp.where branches, which cause failures for
    # reverse-mode AD.
    theta_squared_safe = jnp.where(
        use_taylor,
        1.0,  # Any non-zero value should do here.
        theta_squared,
    )
    del theta_squared
    theta_safe = jnp.sqrt(theta_squared_safe)
    half_theta_safe = theta_safe / 2.0

    # Terms shared by the small-angle series and the closed form.
    affine_part = jnp.eye(3) - 0.5 * skew_omega

    # Coefficient of the squared skew matrix in the closed-form inverse of V.
    closed_form_coeff = (
        1.0
        - theta_safe
        * jnp.cos(half_theta_safe)
        / (2.0 * jnp.sin(half_theta_safe))
    ) / theta_squared_safe

    V_inv = jnp.where(
        use_taylor,
        # Small-angle series: the coefficient tends to 1/12 as theta -> 0.
        affine_part + skew_omega_squared / 12.0,
        affine_part + closed_form_coeff * skew_omega_squared,
    )
    return jnp.concatenate([V_inv @ self.translation(), omega])
[docs]
@override
def adjoint(self) -> jax.Array:
    """Return the adjoint matrix of this transform.

    With ``R`` the rotation matrix and ``t`` the translation, the result is
    the block matrix ``[[R, skew(t) @ R], [0, R]]`` assembled from 3x3
    blocks.
    """
    rot = self.rotation().as_matrix()
    upper_right = _skew(self.translation()) @ rot
    lower_left = jnp.zeros((3, 3))
    return jnp.block(
        [
            [rot, upper_right],
            [lower_left, rot],
        ]
    )