from __future__ import annotations
import jax
import jax_dataclasses as jdc
from jax import numpy as jnp
from typing_extensions import Annotated, override
from . import _base, hints
from .utils import get_epsilon, register_lie_group
@register_lie_group(
    matrix_dim=3,
    parameters_dim=4,
    tangent_dim=3,
    space_dim=3,
)
@jdc.pytree_dataclass
class SO3(jdc.EnforcedAnnotationsMixin, _base.SOBase):
    """Special orthogonal group for 3D rotations.

    Internal parameterization is `(qw, qx, qy, qz)`. Tangent parameterization is
    `(omega_x, omega_y, omega_z)`.
    """

    # SO3-specific.

    wxyz: Annotated[
        jax.Array,
        (..., 4),  # Shape.
        jnp.floating,  # Data-type.
    ]
    """Internal parameters. `(w, x, y, z)` quaternion."""

    @override
    def __repr__(self) -> str:
        """Show the quaternion parameters, rounded to 5 decimals."""
        wxyz = jnp.round(self.wxyz, 5)
        return f"{self.__class__.__name__}(wxyz={wxyz})"

    @staticmethod
    def from_x_radians(theta: hints.Scalar) -> SO3:
        """Generates a x-axis rotation.

        Args:
            theta: X rotation, in radians.

        Returns:
            Output.
        """
        return SO3.exp(jnp.array([theta, 0.0, 0.0]))

    @staticmethod
    def from_y_radians(theta: hints.Scalar) -> SO3:
        """Generates a y-axis rotation.

        Args:
            theta: Y rotation, in radians.

        Returns:
            Output.
        """
        return SO3.exp(jnp.array([0.0, theta, 0.0]))

    @staticmethod
    def from_z_radians(theta: hints.Scalar) -> SO3:
        """Generates a z-axis rotation.

        Args:
            theta: Z rotation, in radians.

        Returns:
            Output.
        """
        return SO3.exp(jnp.array([0.0, 0.0, theta]))

    @staticmethod
    def from_rpy_radians(
        roll: hints.Scalar,
        pitch: hints.Scalar,
        yaw: hints.Scalar,
    ) -> SO3:
        """Generates a transform from a set of Euler angles. Uses the ZYX mobile robot
        convention.

        Args:
            roll: X rotation, in radians. Applied first.
            pitch: Y rotation, in radians. Applied second.
            yaw: Z rotation, in radians. Applied last.

        Returns:
            Output.
        """
        # Rightmost factor is applied first: R = Rz(yaw) @ Ry(pitch) @ Rx(roll).
        return (
            SO3.from_z_radians(yaw)
            @ SO3.from_y_radians(pitch)
            @ SO3.from_x_radians(roll)
        )

    @staticmethod
    def from_quaternion_xyzw(xyzw: hints.Array) -> SO3:
        """Construct a rotation from an `xyzw` quaternion.

        Note that `wxyz` quaternions can be constructed using the default dataclass
        constructor.

        Args:
            xyzw: xyzw quaternion. Shape should be (4,).

        Returns:
            Output.
        """
        assert xyzw.shape == (4,)
        # Rotate the last element (w) to the front: xyzw -> wxyz.
        return SO3(jnp.roll(xyzw, shift=1))

    def as_quaternion_xyzw(self) -> jax.Array:
        """Grab parameters as xyzw quaternion."""
        # Rotate the first element (w) to the back: wxyz -> xyzw.
        return jnp.roll(self.wxyz, shift=-1)

    def as_rpy_radians(self) -> hints.RollPitchYaw:
        """Computes roll, pitch, and yaw angles. Uses the ZYX mobile robot convention.

        Returns:
            Named tuple containing Euler angles in radians.
        """
        return hints.RollPitchYaw(
            roll=self.compute_roll_radians(),
            pitch=self.compute_pitch_radians(),
            yaw=self.compute_yaw_radians(),
        )

    def compute_roll_radians(self) -> jax.Array:
        """Compute roll angle. Uses the ZYX mobile robot convention.

        Returns:
            Euler angle in radians.
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = self.wxyz
        return jnp.arctan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1**2 + q2**2))

    def compute_pitch_radians(self) -> jax.Array:
        """Compute pitch angle. Uses the ZYX mobile robot convention.

        Returns:
            Euler angle in radians, in [-pi/2, pi/2].
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = self.wxyz
        # Floating-point round-off (near pitch = +/-pi/2, or for quaternions that
        # are only approximately unit-norm) can push this value slightly outside
        # [-1, 1], where arcsin returns NaN. Clamp it back into the valid domain.
        sin_pitch = jnp.clip(2 * (q0 * q2 - q3 * q1), -1.0, 1.0)
        return jnp.arcsin(sin_pitch)

    def compute_yaw_radians(self) -> jax.Array:
        """Compute yaw angle. Uses the ZYX mobile robot convention.

        Returns:
            Euler angle in radians.
        """
        # https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles#Quaternion_to_Euler_angles_conversion
        q0, q1, q2, q3 = self.wxyz
        return jnp.arctan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2**2 + q3**2))

    # Factory.

    @staticmethod
    @override
    def identity() -> SO3:
        """Return the identity rotation, quaternion `(1, 0, 0, 0)`."""
        return SO3(wxyz=jnp.array([1.0, 0.0, 0.0, 0.0]))

    @staticmethod
    @override
    def from_matrix(matrix: hints.Array) -> SO3:
        """Construct a rotation from a 3x3 rotation matrix.

        Args:
            matrix: Rotation matrix. Shape should be (3, 3).

        Returns:
            Output.
        """
        assert matrix.shape == (3, 3)

        # Modified from:
        # > "Converting a Rotation Matrix to a Quaternion" from Mike Day
        # > https://d3cw3dd2w32x2b.cloudfront.net/wp-content/uploads/2015/01/matrix-to-quat.pdf
        #
        # Each case returns (t, q), where q is an unnormalized wxyz quaternion and
        # the final result is q * 0.5 / sqrt(t). The cases differ in which
        # quaternion component is assumed to be largest (x, y, z, w respectively),
        # which keeps t away from zero for the selected case.

        def case0(m):
            t = 1 + m[0, 0] - m[1, 1] - m[2, 2]
            q = jnp.array(
                [
                    m[2, 1] - m[1, 2],
                    t,
                    m[1, 0] + m[0, 1],
                    m[0, 2] + m[2, 0],
                ]
            )
            return t, q

        def case1(m):
            t = 1 - m[0, 0] + m[1, 1] - m[2, 2]
            q = jnp.array(
                [
                    m[0, 2] - m[2, 0],
                    m[1, 0] + m[0, 1],
                    t,
                    m[2, 1] + m[1, 2],
                ]
            )
            return t, q

        def case2(m):
            t = 1 - m[0, 0] - m[1, 1] + m[2, 2]
            q = jnp.array(
                [
                    m[1, 0] - m[0, 1],
                    m[0, 2] + m[2, 0],
                    m[2, 1] + m[1, 2],
                    t,
                ]
            )
            return t, q

        def case3(m):
            t = 1 + m[0, 0] + m[1, 1] + m[2, 2]
            q = jnp.array(
                [
                    t,
                    m[2, 1] - m[1, 2],
                    m[0, 2] - m[2, 0],
                    m[1, 0] - m[0, 1],
                ]
            )
            return t, q

        # Compute four cases, then pick the most precise one.
        # Probably worth revisiting this!
        case0_t, case0_q = case0(matrix)
        case1_t, case1_q = case1(matrix)
        case2_t, case2_q = case2(matrix)
        case3_t, case3_q = case3(matrix)

        cond0 = matrix[2, 2] < 0
        cond1 = matrix[0, 0] > matrix[1, 1]
        cond2 = matrix[0, 0] < -matrix[1, 1]

        t = jnp.where(
            cond0,
            jnp.where(cond1, case0_t, case1_t),
            jnp.where(cond2, case2_t, case3_t),
        )
        q = jnp.where(
            cond0,
            jnp.where(cond1, case0_q, case1_q),
            jnp.where(cond2, case2_q, case3_q),
        )

        # Branching with nested `jax.lax.cond` on the same three conditions is an
        # alternative to evaluating all cases and selecting, but it is slower.

        return SO3(wxyz=q * 0.5 / jnp.sqrt(t))

    # Accessors.

    @override
    def as_matrix(self) -> jax.Array:
        """Return the 3x3 rotation matrix for this quaternion."""
        norm = self.wxyz @ self.wxyz
        # Scaling by sqrt(2 / |q|^2) means the outer product below holds
        # 2 * q_i * q_j / |q|^2: the standard factor of 2 is folded in, and
        # non-unit quaternions are implicitly normalized.
        q = self.wxyz * jnp.sqrt(2.0 / norm)
        q = jnp.outer(q, q)
        return jnp.array(
            [
                [1.0 - q[2, 2] - q[3, 3], q[1, 2] - q[3, 0], q[1, 3] + q[2, 0]],
                [q[1, 2] + q[3, 0], 1.0 - q[1, 1] - q[3, 3], q[2, 3] - q[1, 0]],
                [q[1, 3] - q[2, 0], q[2, 3] + q[1, 0], 1.0 - q[1, 1] - q[2, 2]],
            ]
        )

    @override
    def parameters(self) -> jax.Array:
        """Return the underlying `(w, x, y, z)` quaternion."""
        return self.wxyz

    # Operations.

    @override
    def apply(self, target: hints.Array) -> jax.Array:
        """Rotate a 3D point.

        Args:
            target: Point to rotate. Shape should be (3,).

        Returns:
            Rotated point, shape (3,).
        """
        assert target.shape == (3,)

        # Compute using quaternion multiplies: embed the point as a pure quaternion
        # (0, target), conjugate it by this rotation, and drop the real part.
        padded_target = jnp.concatenate([jnp.zeros(1), target])
        return (self @ SO3(wxyz=padded_target) @ self.inverse()).wxyz[1:]

    @override
    def multiply(self, other: SO3) -> SO3:
        """Compose two rotations via the Hamilton quaternion product `self * other`."""
        w0, x0, y0, z0 = self.wxyz
        w1, x1, y1, z1 = other.wxyz
        return SO3(
            wxyz=jnp.array(
                [
                    -x0 * x1 - y0 * y1 - z0 * z1 + w0 * w1,
                    x0 * w1 + y0 * z1 - z0 * y1 + w0 * x1,
                    -x0 * z1 + y0 * w1 + z0 * x1 + w0 * y1,
                    x0 * y1 - y0 * x1 + z0 * w1 + w0 * z1,
                ]
            )
        )

    @staticmethod
    @override
    def exp(tangent: hints.Array) -> SO3:
        """Map a tangent vector (axis * angle, shape (3,)) to a rotation.

        Uses Taylor expansions of cos(theta/2) and sin(theta/2)/theta when the
        squared angle is below the dtype's epsilon.
        """
        # Reference:
        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L583

        assert tangent.shape == (3,)

        theta_squared = tangent @ tangent
        theta_pow_4 = theta_squared * theta_squared
        use_taylor = theta_squared < get_epsilon(tangent.dtype)

        # Shim to avoid NaNs in jnp.where branches, which cause failures for
        # reverse-mode AD.
        safe_theta = jnp.sqrt(
            jnp.where(
                use_taylor,
                1.0,  # Any constant value should do here.
                theta_squared,
            )
        )
        safe_half_theta = 0.5 * safe_theta

        real_factor = jnp.where(
            use_taylor,
            1.0 - theta_squared / 8.0 + theta_pow_4 / 384.0,
            jnp.cos(safe_half_theta),
        )

        imaginary_factor = jnp.where(
            use_taylor,
            0.5 - theta_squared / 48.0 + theta_pow_4 / 3840.0,
            jnp.sin(safe_half_theta) / safe_theta,
        )

        return SO3(
            wxyz=jnp.concatenate(
                [
                    real_factor[None],
                    imaginary_factor * tangent,
                ]
            )
        )

    @override
    def log(self) -> jax.Array:
        """Map this rotation to its tangent vector (axis * angle, shape (3,)).

        Uses a Taylor expansion when the imaginary part's squared norm is below
        the dtype's epsilon, and a +/-pi special case when |w| is below epsilon.
        """
        # Reference:
        # > https://github.com/strasdat/Sophus/blob/a0fe89a323e20c42d3cecb590937eb7a06b8343a/sophus/so3.hpp#L247

        w = self.wxyz[..., 0]
        norm_sq = self.wxyz[..., 1:] @ self.wxyz[..., 1:]
        use_taylor = norm_sq < get_epsilon(norm_sq.dtype)

        # Shim to avoid NaNs in jnp.where branches, which cause failures for
        # reverse-mode AD.
        norm_safe = jnp.sqrt(
            jnp.where(
                use_taylor,
                1.0,  # Any non-zero value should do here.
                norm_sq,
            )
        )
        w_safe = jnp.where(use_taylor, w, 1.0)
        # Dividing by |w| (and moving w's sign onto the numerator) keeps the
        # angle in the shorter-rotation range regardless of quaternion sign.
        atan_n_over_w = jnp.arctan2(
            jnp.where(w < 0, -norm_safe, norm_safe),
            jnp.abs(w),
        )
        atan_factor = jnp.where(
            use_taylor,
            2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3,
            jnp.where(
                jnp.abs(w) < get_epsilon(w.dtype),
                jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
                2.0 * atan_n_over_w / norm_safe,
            ),
        )

        return atan_factor * self.wxyz[1:]

    @override
    def adjoint(self) -> jax.Array:
        """Return the adjoint, which for SO(3) is the rotation matrix itself."""
        return self.as_matrix()

    @override
    def inverse(self) -> SO3:
        """Return the quaternion conjugate (the inverse for unit quaternions)."""
        # Negate imaginary terms.
        return SO3(wxyz=self.wxyz * jnp.array([1, -1, -1, -1]))

    @override
    def normalize(self) -> SO3:
        """Rescale the quaternion to unit norm."""
        return SO3(wxyz=self.wxyz / jnp.linalg.norm(self.wxyz))