from __future__ import annotations
from typing import Tuple
import jax
import jax_dataclasses as jdc
from jax import numpy as jnp
from typing_extensions import override
from . import _base, hints
from .utils import broadcast_leading_axes, register_lie_group
[docs]
@register_lie_group(
matrix_dim=2,
parameters_dim=2,
tangent_dim=1,
space_dim=2,
)
@jdc.pytree_dataclass
class SO2(_base.SOBase):
"""Special orthogonal group for 2D rotations. Broadcasting rules are the
same as for `numpy`.
Internal parameterization is `(cos, sin)`. Tangent parameterization is `(omega,)`.
"""
# SO2-specific.
unit_complex: jax.Array
"""Internal parameters. `(cos, sin)`. Shape should be `(*, 2)`."""
[docs]
@override
def __repr__(self) -> str:
unit_complex = jnp.round(self.unit_complex, 5)
return f"{self.__class__.__name__}(unit_complex={unit_complex})"
[docs]
@staticmethod
def from_radians(theta: hints.Scalar) -> SO2:
"""Construct a rotation object from a scalar angle."""
cos = jnp.cos(theta)
sin = jnp.sin(theta)
return SO2(unit_complex=jnp.stack([cos, sin], axis=-1))
[docs]
def as_radians(self) -> jax.Array:
"""Compute a scalar angle from a rotation object."""
radians = self.log()[..., 0]
return radians
# Factory.
[docs]
@classmethod
@override
def identity(cls, batch_axes: jdc.Static[Tuple[int, ...]] = ()) -> SO2:
return SO2(
unit_complex=jnp.stack(
[jnp.ones(batch_axes), jnp.zeros(batch_axes)], axis=-1
)
)
[docs]
@classmethod
@override
def from_matrix(cls, matrix: hints.Array) -> SO2:
assert matrix.shape[-2:] == (2, 2)
return SO2(unit_complex=jnp.asarray(matrix[..., :, 0]))
# Accessors.
[docs]
@override
def as_matrix(self) -> jax.Array:
cos_sin = self.unit_complex
out = jnp.stack(
[
# [cos, -sin],
cos_sin * jnp.array([1, -1]),
# [sin, cos],
cos_sin[..., ::-1],
],
axis=-2,
)
assert out.shape == (*self.get_batch_axes(), 2, 2)
return out
[docs]
@override
def parameters(self) -> jax.Array:
return self.unit_complex
# Operations.
[docs]
@override
def apply(self, target: hints.Array) -> jax.Array:
assert target.shape[-1:] == (2,)
self, target = broadcast_leading_axes((self, target))
return jnp.einsum("...ij,...j->...i", self.as_matrix(), target)
[docs]
@override
def multiply(self, other: SO2) -> SO2:
return SO2(
unit_complex=jnp.einsum(
"...ij,...j->...i", self.as_matrix(), other.unit_complex
)
)
[docs]
@classmethod
@override
def exp(cls, tangent: hints.Array) -> SO2:
assert tangent.shape[-1] == 1
cos = jnp.cos(tangent)
sin = jnp.sin(tangent)
return SO2(unit_complex=jnp.concatenate([cos, sin], axis=-1))
[docs]
@override
def log(self) -> jax.Array:
return jnp.arctan2(
self.unit_complex[..., 1, None], self.unit_complex[..., 0, None]
)
[docs]
@override
def adjoint(self) -> jax.Array:
return jnp.ones((*self.get_batch_axes(), 1, 1))
[docs]
@override
def inverse(self) -> SO2:
return SO2(unit_complex=self.unit_complex * jnp.array([1, -1]))
[docs]
@override
def normalize(self) -> SO2:
return SO2(
unit_complex=self.unit_complex
/ jnp.linalg.norm(self.unit_complex, axis=-1, keepdims=True)
)
[docs]
@override
def jlog(self) -> jax.Array:
batch_axes = self.get_batch_axes()
ones = jnp.ones(batch_axes)
return ones[..., None, None]