from __future__ import annotations
import jax
import jax_dataclasses as jdc
from jax import numpy as jnp
from typing_extensions import Annotated, override
from . import _base, hints
from .utils import register_lie_group
@register_lie_group(
    matrix_dim=2,
    parameters_dim=2,
    tangent_dim=1,
    space_dim=2,
)
@jdc.pytree_dataclass
class SO2(jdc.EnforcedAnnotationsMixin, _base.SOBase):
    """Special orthogonal group for 2D rotations.

    Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`.

    The `(..., 2)` shape annotation on `unit_complex` admits leading batch axes.
    Every method below indexes, stacks and reduces over trailing axes only, so a
    single rotation (shape `(2,)`) and a batch of rotations (shape `(*batch, 2)`)
    go through the same code.
    """

    # SO2-specific.

    unit_complex: Annotated[
        jax.Array,
        (..., 2),  # Shape.
        jnp.floating,  # Data-type.
    ]
    """Internal parameters. `(cos, sin)`."""

    @override
    def __repr__(self) -> str:
        # Round to 5 decimals to keep the printed form short.
        unit_complex = jnp.round(self.unit_complex, 5)
        return f"{self.__class__.__name__}(unit_complex={unit_complex})"

    @staticmethod
    def from_radians(theta: hints.Scalar) -> SO2:
        """Construct a rotation object from a scalar angle.

        Args:
            theta: Rotation angle in radians. An array of angles yields a batch
                of rotations whose batch shape equals `theta.shape`.

        Returns:
            Rotation with `unit_complex = (cos(theta), sin(theta))` on the last axis.
        """
        cos = jnp.cos(theta)
        sin = jnp.sin(theta)
        # Stack on the LAST axis so batched angles give `(*batch, 2)`, matching
        # the `(..., 2)` annotation (`jnp.array([...])` would give `(2, *batch)`).
        return SO2(unit_complex=jnp.stack([cos, sin], axis=-1))

    def as_radians(self) -> jax.Array:
        """Compute a scalar angle from a rotation object.

        Returns:
            The angle in radians from `log()` with the trailing tangent axis
            dropped, in the output range of `jnp.arctan2` (`[-pi, pi]`).
        """
        radians = self.log()[..., 0]
        return radians

    # Factory.

    @staticmethod
    @override
    def identity() -> SO2:
        """Return the identity rotation, `unit_complex = (1, 0)`."""
        return SO2(unit_complex=jnp.array([1.0, 0.0]))

    @staticmethod
    @override
    def from_matrix(matrix: hints.Array) -> SO2:
        """Construct a rotation from a 2x2 rotation matrix.

        Args:
            matrix: Array of shape `(..., 2, 2)`. Only the first column is read
                (for a rotation matrix it equals `(cos, sin)`); orthogonality is
                not checked.

        Returns:
            Rotation whose `unit_complex` is the first column of `matrix`.
        """
        assert matrix.shape[-2:] == (2, 2)
        return SO2(unit_complex=jnp.asarray(matrix[..., :, 0]))

    # Accessors.

    @override
    def as_matrix(self) -> jax.Array:
        """Return the rotation matrix `[[cos, -sin], [sin, cos]]`.

        Returns:
            Array of shape `(*batch, 2, 2)`; `(2, 2)` for an unbatched rotation.
        """
        cos_sin = self.unit_complex
        out = jnp.stack(
            [
                # [cos, -sin],
                cos_sin * jnp.array([1, -1]),
                # [sin, cos],
                cos_sin[..., ::-1],
            ],
            # Rows are placed on the second-to-last axis so batch axes stay leading.
            axis=-2,
        )
        assert out.shape == (*cos_sin.shape[:-1], 2, 2)
        return out

    @override
    def parameters(self) -> jax.Array:
        """Return the internal `(cos, sin)` parameter array."""
        return self.unit_complex

    # Operations.

    @override
    def apply(self, target: hints.Array) -> jax.Array:
        """Rotate a 2D point (or a batch of points).

        Args:
            target: Array of shape `(..., 2)`. Leading axes broadcast against
                this rotation's batch axes.

        Returns:
            `R @ target` evaluated per batch element, shape `(..., 2)`.
        """
        assert target.shape[-1:] == (2,)
        # `@` would treat a batched `(N, 2)` target as a matrix; einsum keeps a
        # matrix-vector product per batch element.
        return jnp.einsum("...ij,...j->...i", self.as_matrix(), target)

    @override
    def multiply(self, other: SO2) -> SO2:
        """Compose rotations: the result applies `other` first, then `self`.

        Computed as `R_self @ other.unit_complex`, since the first column of
        `R_self @ R_other` is exactly that product.
        """
        return SO2(
            unit_complex=jnp.einsum(
                "...ij,...j->...i", self.as_matrix(), other.unit_complex
            )
        )

    @staticmethod
    @override
    def exp(tangent: hints.Array) -> SO2:
        """Map a tangent vector `(omega,)` to a rotation.

        Args:
            tangent: Array of shape `(..., 1)` holding the angle in radians.

        Returns:
            Rotation with `unit_complex = (cos(omega), sin(omega))`.
        """
        tangent = jnp.asarray(tangent)
        assert tangent.shape[-1:] == (1,)
        # Keep the trailing length-1 axis so cos/sin concatenate to `(..., 2)`;
        # tuple-unpacking the tangent only worked for unbatched input.
        return SO2(
            unit_complex=jnp.concatenate(
                [jnp.cos(tangent), jnp.sin(tangent)], axis=-1
            )
        )

    @override
    def log(self) -> jax.Array:
        """Map this rotation to its tangent vector `(omega,)`.

        Returns:
            Array of shape `(..., 1)`: `arctan2(sin, cos)`, in `[-pi, pi]`.
        """
        # `None` re-inserts the trailing tangent axis of length 1.
        return jnp.arctan2(
            self.unit_complex[..., 1, None], self.unit_complex[..., 0, None]
        )

    @override
    def adjoint(self) -> jax.Array:
        """Return the adjoint, a 1x1 identity per rotation.

        Returns:
            Array of ones with shape `(*batch, 1, 1)`; `(1, 1)` when unbatched.
        """
        return jnp.ones((*self.unit_complex.shape[:-1], 1, 1))

    @override
    def inverse(self) -> SO2:
        """Return the inverse rotation by negating the `sin` component."""
        return SO2(unit_complex=self.unit_complex * jnp.array([1, -1]))

    @override
    def normalize(self) -> SO2:
        """Project `unit_complex` back to unit length.

        The norm is taken per rotation (over the last axis only). Without
        `axis=-1`, a batch would be divided by one global norm.
        """
        return SO2(
            unit_complex=self.unit_complex
            / jnp.linalg.norm(self.unit_complex, axis=-1, keepdims=True)
        )