Source code for jaxlie._base

import abc
from typing import ClassVar, Generic, Tuple, Type, TypeVar, Union, overload

import jax
import numpy as onp
from typing_extensions import final, override

from . import hints

# Type variable bound to `MatrixLieGroup`. Used to annotate `self`/`cls` so that
# methods are typed as returning the same concrete group type they are called on.
GroupType = TypeVar("GroupType", bound="MatrixLieGroup")
# Same idea as `GroupType`, but bound to `SEBase`.
SEGroupType = TypeVar("SEGroupType", bound="SEBase")


class MatrixLieGroup(abc.ABC):
    """Interface definition for matrix Lie groups."""

    # Class properties.
    # > These will be set in `_utils.register_lie_group()`.

    matrix_dim: ClassVar[int]
    """Dimension of square matrix output from `.as_matrix()`."""

    parameters_dim: ClassVar[int]
    """Dimension of underlying parameters, `.parameters()`."""

    tangent_dim: ClassVar[int]
    """Dimension of tangent space."""

    space_dim: ClassVar[int]
    """Dimension of coordinates that can be transformed."""

    def __init__(
        # Notes:
        # - For the constructor signature to be consistent with subclasses,
        #   `parameters` should be marked as positional-only. But this isn't
        #   possible in Python 3.7.
        # - This method is implicitly overridden by the dataclass decorator and
        #   should _not_ be marked abstract.
        self,
        parameters: jax.Array,
    ):
        """Construct a group object from its underlying parameters.

        Args:
            parameters: Underlying parameters of the group member.

        Raises:
            NotImplementedError: Always, when this base implementation is
                called directly.
        """
        raise NotImplementedError()

    # Shared implementations.

    @overload
    def __matmul__(self: GroupType, other: GroupType) -> GroupType:
        ...

    @overload
    def __matmul__(self, other: hints.Array) -> jax.Array:
        ...

    def __matmul__(
        self: GroupType, other: Union[GroupType, hints.Array]
    ) -> Union[GroupType, jax.Array]:
        """Overload for the `@` operator.

        Switches between the group action (`.apply()`) and multiplication
        (`.multiply()`) based on the type of `other`.

        Args:
            other: Either an array (`numpy.ndarray` or `jax.Array`) to apply the
                group action to, or another group member to compose with.

        Returns:
            The result of `.apply()` for array inputs, or of `.multiply()` for
            group inputs.

        Raises:
            AssertionError: If `other` is a group member with a different
                `space_dim`, or if `other` is neither an array nor a group
                member.
        """
        if isinstance(other, (onp.ndarray, jax.Array)):
            return self.apply(target=other)

        if isinstance(other, MatrixLieGroup):
            # Raised explicitly rather than via `assert`, so that the check is
            # not stripped when Python runs with `-O`.
            if self.space_dim != other.space_dim:
                raise AssertionError(
                    "Cannot compose groups with different `space_dim`: "
                    f"{self.space_dim} != {other.space_dim}"
                )
            return self.multiply(other=other)

        # Explicit raise: an `assert False` here would vanish under `-O` and the
        # operator would silently return `None` for unsupported operands.
        raise AssertionError(f"Invalid argument type for `@` operator: {type(other)}")

    # Factory.

    @classmethod
    @abc.abstractmethod
    def identity(cls: Type[GroupType]) -> GroupType:
        """Returns identity element.

        Returns:
            Identity element.
        """

    @classmethod
    @abc.abstractmethod
    def from_matrix(cls: Type[GroupType], matrix: hints.Array) -> GroupType:
        """Get group member from matrix representation.

        Args:
            matrix: Matrix representation.

        Returns:
            Group member.
        """

    # Accessors.

    @abc.abstractmethod
    def as_matrix(self) -> jax.Array:
        """Get transformation as a matrix. Homogeneous for SE groups."""

    @abc.abstractmethod
    def parameters(self) -> jax.Array:
        """Get underlying representation."""

    # Operations.

    @abc.abstractmethod
    def apply(self, target: hints.Array) -> jax.Array:
        """Applies group action to a point.

        Args:
            target: Point to transform.

        Returns:
            Transformed point.
        """

    @abc.abstractmethod
    def multiply(self: GroupType, other: GroupType) -> GroupType:
        """Composes this transformation with another.

        Args:
            other: Group member to compose with, on the right.

        Returns:
            self @ other
        """

    @classmethod
    @abc.abstractmethod
    def exp(cls: Type[GroupType], tangent: hints.Array) -> GroupType:
        """Computes `expm(wedge(tangent))`.

        Args:
            tangent: Tangent vector to take the exponential of.

        Returns:
            Output.
        """

    @abc.abstractmethod
    def log(self) -> jax.Array:
        """Computes `vee(logm(transformation matrix))`.

        Returns:
            Output. Shape should be `(tangent_dim,)`.
        """

    @abc.abstractmethod
    def adjoint(self) -> jax.Array:
        """Computes the adjoint, which transforms tangent vectors between tangent
        spaces.

        More precisely, for a transform `GroupType`:
        ```
        GroupType @ exp(omega) = exp(Adj_T @ omega) @ GroupType
        ```

        In robotics, typically used for transforming twists, wrenches, and
        Jacobians across different reference frames.

        Returns:
            Output. Shape should be `(tangent_dim, tangent_dim)`.
        """

    @abc.abstractmethod
    def inverse(self: GroupType) -> GroupType:
        """Computes the inverse of our transform.

        Returns:
            Output.
        """

    @abc.abstractmethod
    def normalize(self: GroupType) -> GroupType:
        """Normalize/projects values and returns.

        Returns:
            GroupType: Normalized group member.
        """

    @classmethod
    @abc.abstractmethod
    def sample_uniform(cls: Type[GroupType], key: jax.Array) -> GroupType:
        """Draw a uniform sample from the group. Translations (if applicable) are
        in the range [-1, 1].

        Args:
            key: PRNG key, as returned by `jax.random.PRNGKey()`.

        Returns:
            Sampled group member.
        """

    @abc.abstractmethod
    def get_batch_axes(self) -> Tuple[int, ...]:
        """Return any leading batch axes in contained parameters. If an array of
        shape `(100, 4)` is placed in the wxyz field of an SO3 object, for
        example, this will return `(100,)`.

        This should generally be implemented by `jdc.EnforcedAnnotationsMixin`."""
class SOBase(MatrixLieGroup):
    """Common base class shared by the special orthogonal (rotation) groups."""
# Type variable bound to `SOBase`. `SEBase` is generic over it, and uses it to
# annotate the rotation that an SE group member takes in and returns.
ContainedSOType = TypeVar("ContainedSOType", bound=SOBase)
class SEBase(Generic[ContainedSOType], MatrixLieGroup):
    """Shared base for special Euclidean groups.

    An SE(N) group member pairs an SO(N) rotation with an N-dimensional
    translation vector.
    """

    # SE-specific interface.

    @classmethod
    @abc.abstractmethod
    def from_rotation_and_translation(
        cls: Type[SEGroupType],
        rotation: ContainedSOType,
        translation: hints.Array,
    ) -> SEGroupType:
        """Build a rigid transform out of a rotation and a translation.

        Args:
            rotation: Rotation term.
            translation: Translation term.

        Returns:
            Constructed transformation.
        """

    @final
    @classmethod
    def from_rotation(cls: Type[SEGroupType], rotation: ContainedSOType) -> SEGroupType:
        """Build a transform from a rotation alone, with an all-zero translation.

        The translation has `cls.space_dim` entries and takes its dtype from the
        rotation's parameters.

        Args:
            rotation: Rotation term.

        Returns:
            Constructed transformation.
        """
        zero_translation = onp.zeros(
            cls.space_dim,
            dtype=rotation.parameters().dtype,
        )
        return cls.from_rotation_and_translation(
            rotation=rotation,
            translation=zero_translation,
        )

    @abc.abstractmethod
    def rotation(self) -> ContainedSOType:
        """Returns the rotation term of this transform."""

    @abc.abstractmethod
    def translation(self) -> jax.Array:
        """Returns the translation term of this transform."""

    # Overrides.

    @final
    @override
    def apply(self, target: hints.Array) -> jax.Array:
        """Rotate `target` by the rotation term, then add the translation term.

        Args:
            target: Point to transform.

        Returns:
            Transformed point.
        """
        rotated = self.rotation() @ target  # type: ignore
        return rotated + self.translation()

    @final
    @override
    def multiply(self: SEGroupType, other: SEGroupType) -> SEGroupType:
        """Compose this transform with `other` (i.e. `self @ other`).

        Args:
            other: Transform to compose with, on the right.

        Returns:
            Transform whose rotation is the product of both rotations, and whose
            translation is `other`'s translation rotated by our rotation, plus
            our own translation.
        """
        composed_rotation = self.rotation() @ other.rotation()
        composed_translation = (
            self.rotation() @ other.translation()
        ) + self.translation()
        return type(self).from_rotation_and_translation(
            rotation=composed_rotation,
            translation=composed_translation,
        )

    @final
    @override
    def inverse(self: SEGroupType) -> SEGroupType:
        """Compute the inverse of this transform.

        Returns:
            Transform built from the inverted rotation, and the negated result of
            applying that inverted rotation to our translation.
        """
        inverse_rotation = self.rotation().inverse()
        inverse_translation = -(inverse_rotation @ self.translation())
        return type(self).from_rotation_and_translation(
            rotation=inverse_rotation,
            translation=inverse_translation,
        )

    @final
    @override
    def normalize(self: SEGroupType) -> SEGroupType:
        """Normalize the rotation term; the translation is passed through as-is.

        Returns:
            Transform with a normalized rotation and the same translation.
        """
        normalized_rotation = self.rotation().normalize()
        return type(self).from_rotation_and_translation(
            rotation=normalized_rotation,
            translation=self.translation(),
        )