Source code for jaxlie.utils._utils

from typing import TYPE_CHECKING, Callable, Tuple, Type, TypeVar, Union, cast

import jax_dataclasses as jdc
from jax import numpy as jnp

from jaxlie.hints import Array

if TYPE_CHECKING:
    from .._base import MatrixLieGroup


T = TypeVar("T", bound="MatrixLieGroup")


def get_epsilon(dtype: jnp.dtype) -> float:
    """Helper for grabbing type-specific precision constants.

    Args:
        dtype: Datatype. Besides `jnp.dtype` instances, any specifier that
            `jnp.dtype()` can interpret is accepted (for example the scalar
            type `jnp.float32` or the string `"float32"`).

    Returns:
        Output float: `1e-5` for float32, `1e-10` for float64.

    Raises:
        KeyError: If the datatype is neither float32 nor float64.
        TypeError: If `dtype` cannot be interpreted as a datatype at all.
    """
    epsilons = {
        jnp.dtype("float32"): 1e-5,
        jnp.dtype("float64"): 1e-10,
    }
    # The table is keyed by dtype instances, so normalize the argument first;
    # otherwise equivalent specifiers (scalar types, strings) miss the lookup.
    return epsilons[jnp.dtype(dtype)]
def register_lie_group(
    *,
    matrix_dim: int,
    parameters_dim: int,
    tangent_dim: int,
    space_dim: int,
) -> Callable[[Type[T]], Type[T]]:
    """Build a class decorator that registers a Lie group dataclass.

    The returned decorator stores the four dimensionality values on the class
    as class attributes, then wraps every public callable attribute of the
    class with `jdc.jit` so that it is marked for JIT compilation.

    Args:
        matrix_dim: Value stored as `cls.matrix_dim`.
        parameters_dim: Value stored as `cls.parameters_dim`.
        tangent_dim: Value stored as `cls.tangent_dim`.
        space_dim: Value stored as `cls.space_dim`.

    Returns:
        A decorator that modifies the given class in place and returns it.
    """
    dimensions = {
        "matrix_dim": matrix_dim,
        "parameters_dim": parameters_dim,
        "tangent_dim": tangent_dim,
        "space_dim": space_dim,
    }

    def _wrap(cls: Type[T]) -> Type[T]:
        # Register dimensions as class attributes.
        for dim_name, dim_value in dimensions.items():
            setattr(cls, dim_name, dim_value)

        # JIT every public callable attribute.
        for attr_name in dir(cls):
            if attr_name.startswith("_"):
                continue
            if not callable(getattr(cls, attr_name)):
                continue
            if attr_name == "get_batch_axes":
                # Left un-jitted to avoid returning tracers.
                continue
            setattr(cls, attr_name, jdc.jit(getattr(cls, attr_name)))

        return cls

    return _wrap
# Type variable for the tuple accepted and returned by `broadcast_leading_axes`:
# a tuple whose elements are Lie group objects or arrays. The bound is written
# as a string, so it is never evaluated at runtime; `MatrixLieGroup` is only
# imported under `TYPE_CHECKING` in this module.
TupleOfBroadcastable = TypeVar( "TupleOfBroadcastable", bound="Tuple[Union[MatrixLieGroup, Array], ...]", )
def broadcast_leading_axes(inputs: TupleOfBroadcastable) -> TupleOfBroadcastable:
    """Broadcast the leading (batch) axes of several inputs against each other.

    Each element of `inputs` is either:
    - an array, which we assume has shape (*, D); its last axis is kept as-is.
    - a Lie group object; its parameter array is broadcast and a new object of
      the same type is built from the result.

    Args:
        inputs: Tuple of arrays and/or Lie group objects.

    Returns:
        Tuple with one entry per input, in the same order, where all entries
        share the same leading axes.
    """
    # Imported here rather than at module level; the top of the module only
    # imports it for type checking.
    from .._base import MatrixLieGroup

    # For every input, collect the underlying array and the trailing shape
    # that must be left untouched by broadcasting.
    raw_arrays = []
    suffixes = []
    for item in inputs:
        if isinstance(item, MatrixLieGroup):
            raw_arrays.append(item.parameters())
            suffixes.append((item.parameters_dim,))
        else:
            raw_arrays.append(item)
            suffixes.append(item.shape[-1:])

    # Sanity check: each array must actually end with its expected suffix.
    for arr, suffix in zip(raw_arrays, suffixes):
        assert arr.shape[-len(suffix) :] == suffix

    # Everything in front of the suffix is treated as batch axes.
    leading_shapes = [
        arr.shape[: -len(suffix)] for arr, suffix in zip(raw_arrays, suffixes)
    ]
    batch_axes = jnp.broadcast_shapes(*leading_shapes)

    results = []
    for item, arr, suffix in zip(inputs, raw_arrays, suffixes):
        expanded = jnp.broadcast_to(arr, batch_axes + suffix)
        if isinstance(item, MatrixLieGroup):
            # Re-wrap the broadcast parameters in the original group type.
            results.append(type(item)(expanded))
        else:
            results.append(expanded)

    return cast(TupleOfBroadcastable, tuple(results))