# jax.vmap Usage

"""jaxlie implements numpy-style broadcasting for all operations. For more
explicit vectorization, we can also use `jax.vmap` function transformations.

JIT compilation is omitted here for brevity, but in practice we usually want to
apply `jax.jit` after vmapping."""

import jax
import numpy as onp

from jaxlie import SO3

# Size of the batch axis used by every stacked example below.
N = 100

#############################
# (1) Setup.
#############################

# Two rotation objects are created first:
# - `R_single` holds one rotation, built from an x-axis angle of pi / 2 radians.
# - `R_stacked` holds `N` rotations stacked together! Every Lie group object is a
#   PyTree, so `R_stacked` has exactly the same structure as `R_single`; the only
#   difference is an extra leading batch axis on its parameter array.

R_single = SO3.from_x_radians(onp.pi / 2.0)
assert R_single.wxyz.shape == (4,)

# Map the single-angle constructor over N angles drawn uniformly from [-pi, pi).
_from_x_radians_batched = jax.vmap(SO3.from_x_radians)
_x_angles = onp.random.uniform(low=-onp.pi, high=onp.pi, size=(N,))
R_stacked = _from_x_radians_batched(_x_angles)
assert R_stacked.wxyz.shape == (N, 4)

# Two point arrays, drawn in this order: a lone 3D point, then `N` points stacked
# along a leading axis.
p_single, p_stacked = [
    onp.random.uniform(size=point_shape) for point_shape in ((3,), (N, 3))
]

#############################
# (2) Applying 1 transformation to 1 point.
#############################

# Three spellings of the same operation, evaluated in turn. The `@` operator and the
# bound `.apply()` method are just syntactic sugar for calling the unbound
# `SO3.apply` with the rotation passed explicitly as its first argument.
for _rotate_single in (
    lambda: R_single @ p_single,
    lambda: R_single.apply(p_single),
    lambda: SO3.apply(R_single, p_single),
):
    p_transformed_single = _rotate_single()
    # One rotation applied to one point yields one point.
    assert p_transformed_single.shape == (3,)


#############################
# (3) Applying 1 transformation to N points.
#############################

# Plain `jax.vmap` semantics: mapping the bound method batches only the point
# argument, while the captured rotation is shared across all `N` points.
_apply_to_each_point = jax.vmap(R_single.apply)
p_transformed_stacked = _apply_to_each_point(p_stacked)
# One output point per input point, i.e. shape (N, 3).
assert p_transformed_stacked.shape == p_stacked.shape


# Equivalently, close over `R_single` and map the unbound `SO3.apply`.
def _apply_rotation(p):
    return SO3.apply(R_single, p)


p_transformed_stacked = jax.vmap(_apply_rotation)(p_stacked)
assert p_transformed_stacked.shape == p_stacked.shape

# Broadcasting handles this case as well, with no explicit vmap.
p_transformed_stacked = R_single @ p_stacked
assert p_transformed_stacked.shape == p_stacked.shape

#############################
# (4) Applying N transformations to N points.
#############################

# `R_stacked` and `p_stacked` each carry a leading (N,) batch axis relative to their
# "single" counterparts, so both arguments of `SO3.apply` are mapped along axis 0.
# `in_axes=(0, 0)` is written out for clarity; it matches `jax.vmap`'s default.
_apply_pairwise = jax.vmap(SO3.apply, in_axes=(0, 0))
p_transformed_stacked = _apply_pairwise(R_stacked, p_stacked)
assert p_transformed_stacked.shape == p_stacked.shape

# Broadcasting pairs the i-th rotation with the i-th point in the same way.
p_transformed_stacked = R_stacked @ p_stacked
assert p_transformed_stacked.shape == p_stacked.shape

#############################
# (5) Applying N transformations to 1 point.
#############################

# Map over the rotations only: `in_axes=None` for the point means the same
# `p_single` is fed, unbatched, to every mapped call.
p_transformed_stacked = jax.vmap(SO3.apply, in_axes=(0, None))(R_stacked, p_single)
assert p_transformed_stacked.shape == (N, 3)

# Broadcasting also works; here the point is given a leading length-1 axis that
# broadcasts against the (N,) batch of rotations.
p_transformed_stacked = R_stacked @ p_single[onp.newaxis]
assert p_transformed_stacked.shape == (N, 3)

#############################
# (6) Multiplying transformations.
#############################

# Everything above carries over to other operations. For composition, the
# operator, the bound method, and the unbound function are interchangeable, and
# two single rotations compose into one:
for _compose_single in (
    lambda: R_single @ R_single,
    lambda: R_single.multiply(R_single),
    lambda: SO3.multiply(R_single, R_single),
):
    assert _compose_single().wxyz.shape == (4,)

# Mixing in `R_stacked` produces N rotations, whichever batching pattern is used.
for _compose_batched in (
    # 1 x N: fixed left factor, mapped right factor.
    lambda: jax.vmap(R_single.multiply)(R_stacked),
    lambda: jax.vmap(lambda R: SO3.multiply(R_single, R))(R_stacked),
    # N x N: both factors mapped together.
    lambda: jax.vmap(SO3.multiply)(R_stacked, R_stacked),
    # N x 1: mapped left factor, fixed right factor.
    lambda: jax.vmap(lambda R: SO3.multiply(R, R_single))(R_stacked),
    # Broadcasting, with no explicit vmap: N x N, then N x 1 (the single
    # rotation is rebuilt with a leading length-1 batch axis).
    lambda: R_stacked @ R_stacked,
    lambda: R_stacked @ SO3(R_single.wxyz[onp.newaxis]),
):
    assert _compose_batched().wxyz.shape == (N, 4)