jaxlie._base

Module Contents

Classes

MatrixLieGroup

Interface definition for matrix Lie groups.

SOBase

Base class for special orthogonal groups.

SEBase

Base class for special Euclidean groups.

Attributes

ContainedSOType

class jaxlie._base.MatrixLieGroup(parameters)

Bases: abc.ABC

Inheritance diagram of jaxlie.MatrixLieGroup

Interface definition for matrix Lie groups.

matrix_dim: ClassVar[int]

Dimension of square matrix output from .as_matrix().

parameters_dim: ClassVar[int]

Dimension of the underlying parameters returned by .parameters().

tangent_dim: ClassVar[int]

Dimension of tangent space.

space_dim: ClassVar[int]

Dimension of coordinates that can be transformed.

__matmul__(other: typing_extensions.Self) → typing_extensions.Self
__matmul__(other: jaxlie.hints.Array) → jax.Array

Overload for the @ operator.

Switches between the group action (.apply()) and multiplication (.multiply()) based on the type of other.

abstract classmethod identity(batch_axes=())

Returns identity element.

Parameters:

batch_axes (Tuple[int, Ellipsis]) – Any leading batch axes for the output transform.

Returns:

Identity element.

Return type:

typing_extensions.Self

abstract classmethod from_matrix(matrix)

Get group member from matrix representation.

Parameters:

matrix (jaxlie.hints.Array) – Matrix representation.

Returns:

Group member.

Return type:

typing_extensions.Self

abstract as_matrix()

Get transformation as a matrix. Homogeneous for SE groups.

Return type:

jax.Array

abstract parameters()

Get underlying representation.

Return type:

jax.Array

abstract apply(target)

Applies group action to a point.

Parameters:

target (jaxlie.hints.Array) – Point to transform.

Returns:

Transformed point.

Return type:

jax.Array

abstract multiply(other)

Composes this transformation with another.

Returns:

self @ other

Parameters:

other (typing_extensions.Self) – Transform to compose on the right.

Return type:

typing_extensions.Self

abstract classmethod exp(tangent)

Computes expm(wedge(tangent)).

Parameters:

tangent (jaxlie.hints.Array) – Tangent vector to take the exponential of.

Returns:

The group member exp(tangent).

Return type:

typing_extensions.Self

abstract log()

Computes vee(logm(transformation matrix)).

Returns:

Tangent vector. Shape should be (tangent_dim,).

Return type:

jax.Array

abstract adjoint()

Computes the adjoint, which transforms tangent vectors between tangent spaces.

More precisely, for a transform T with adjoint Adj_T = T.adjoint():

T @ exp(omega) = exp(Adj_T @ omega) @ T

In robotics, typically used for transforming twists, wrenches, and Jacobians across different reference frames.

Returns:

Adjoint matrix. Shape should be (tangent_dim, tangent_dim).

Return type:

jax.Array

abstract inverse()

Computes the inverse of this transform.

Returns:

Inverse transform.

Return type:

typing_extensions.Self

abstract normalize()

Normalizes (projects) the underlying parameters back onto the group and returns the result.

Returns:

Normalized group member.

Return type:

typing_extensions.Self

abstract classmethod sample_uniform(key, batch_axes=())

Draw a uniform sample from the group. Translations (if applicable) are in the range [-1, 1].

Parameters:
  • key (jax.Array) – PRNG key, as returned by jax.random.PRNGKey().

  • batch_axes (Tuple[int, Ellipsis]) – Any leading batch axes for the output transforms. Each sampled transform will be different.

Returns:

Sampled group member.

Return type:

typing_extensions.Self

get_batch_axes()

Return any leading batch axes in contained parameters. If an array of shape (100, 4) is placed in the wxyz field of an SO3 object, for example, this will return (100,).

Return type:

Tuple[int, Ellipsis]

class jaxlie._base.SOBase(parameters)

Bases: MatrixLieGroup

Inheritance diagram of jaxlie.SOBase

Base class for special orthogonal groups.

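jaxlie._base.ContainedSOType

Type variable for the SO(N) rotation type contained in an SEBase subclass; for example, the rotation term of an SE(3) transform is an SO(3) instance.

class jaxlie._base.SEBase(parameters)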

Bases: Generic[ContainedSOType], MatrixLieGroup

Inheritance diagram of jaxlie.SEBase

Base class for special Euclidean groups.

Each SE(N) group member contains an SO(N) rotation, as well as an N-dimensional translation vector.

abstract classmethod from_rotation_and_translation(rotation, translation)

Construct a rigid transform from a rotation and a translation.

Parameters:
  • rotation (ContainedSOType) – Rotation term.

  • translation (jaxlie.hints.Array) – Translation term.

Returns:

Constructed transformation.

Return type:

typing_extensions.Self

classmethod from_rotation(rotation)

Construct a transform from a rotation term, with zero translation.

Parameters:

rotation (ContainedSOType) – Rotation term.

Return type:

typing_extensions.Self

classmethod from_translation(translation)

Construct a transform from a translation term, with identity rotation.

Parameters:

translation (jaxlie.hints.Array) – Translation term.

Return type:

typing_extensions.Self

abstract rotation()

Returns a transform’s rotation term.

Return type:

ContainedSOType

abstract translation()

Returns a transform’s translation term.

Return type:

jax.Array

apply(target)

Applies group action to a point.

Parameters:

target (jaxlie.hints.Array) – Point to transform.

Returns:

Transformed point.

Return type:

jax.Array

multiply(other)

Composes this transformation with another.

Returns:

self @ other

Parameters:

other (typing_extensions.Self) – Transform to compose on the right.

Return type:

typing_extensions.Self

inverse()

Computes the inverse of this transform.

Returns:

Inverse transform.

Return type:

typing_extensions.Self

normalize()

Normalizes (projects) the underlying parameters back onto the group and returns the result.

Returns:

Normalized group member.

Return type:

typing_extensions.Self