SE(3) Optimization

"""Example that uses helpers in `jaxlie.manifold.*` to compare algorithms for running an
ADAM optimizer on SE(3) variables.

We compare three approaches:

(1) Tangent-space ADAM: computing updates on a local tangent space, which are then
retracted back to the global parameterization at each step. This should generally be the
most stable.

(2) Projected ADAM: running standard ADAM directly on the global parameterization, then
projecting the result back onto the manifold after each step.

(3) Standard ADAM with exponential coordinates: using a log-space underlying
parameterization lets us run ADAM without any modifications.

Note that the number of training steps and learning rate can be configured, see:

    python se3_optimization.py --help

"""

from __future__ import annotations

import time
from typing import List, Literal, Tuple, Union

import jax
import jax_dataclasses as jdc
import matplotlib.pyplot as plt
import optax
import tyro
from jax import numpy as jnp
from typing_extensions import assert_never

import jaxlie


@jdc.pytree_dataclass
class Parameters:
    """Optimization variables, kept in their global (group-element) form; each
    rotation is stored as a quaternion internally.

    The three transforms over-parameterize the problem: `T_ca` is already implied by
    the other two as `(T_ab @ T_bc).inverse()`. The whole point of the optimization is
    to pull these redundant transforms into agreement.
    """

    # The consistency term in `compute_loss` composes these in order:
    # T_ab @ T_bc @ T_ca.
    T_ab: jaxlie.SE3
    T_bc: jaxlie.SE3
    T_ca: jaxlie.SE3


@jdc.pytree_dataclass
class ExponentialCoordinatesParameters:
    """Log-space twin of `Parameters`.

    Each transform is stored as its SE(3) tangent vector (exponential coordinates) and
    is mapped back to a group element with `jaxlie.SE3.exp` every time it is accessed.
    This exposes the same `T_ab` / `T_bc` / `T_ca` attributes as `Parameters`.
    """

    log_T_ab: jax.Array
    log_T_bc: jax.Array
    log_T_ca: jax.Array

    @property
    def T_ab(self) -> jaxlie.SE3:
        """`T_ab` as a group element, recovered from `log_T_ab`."""
        return jaxlie.SE3.exp(self.log_T_ab)

    @property
    def T_bc(self) -> jaxlie.SE3:
        """`T_bc` as a group element, recovered from `log_T_bc`."""
        return jaxlie.SE3.exp(self.log_T_bc)

    @property
    def T_ca(self) -> jaxlie.SE3:
        """`T_ca` as a group element, recovered from `log_T_ca`."""
        return jaxlie.SE3.exp(self.log_T_ca)

    @staticmethod
    def from_global(params: Parameters) -> ExponentialCoordinatesParameters:
        """Build the log-space representation of `params` by taking `.log()` of
        each transform."""
        log_ab, log_bc, log_ca = (
            transform.log() for transform in (params.T_ab, params.T_bc, params.T_ca)
        )
        return ExponentialCoordinatesParameters(
            log_T_ab=log_ab,
            log_T_bc=log_bc,
            log_T_ca=log_ca,
        )


def compute_loss(
    params: Union[Parameters, ExponentialCoordinatesParameters],
) -> jax.Array:
    """Scalar objective: squared tangent-space residuals of (a) a cycle-consistency
    constraint and (b) two priors on the individual transforms.

    The priors are sampled from fixed PRNG keys, so they are the same on every call.
    """
    prior_ba = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(1))
    prior_cb = jaxlie.SE3.sample_uniform(jax.random.PRNGKey(2))

    T_ab, T_bc, T_ca = params.T_ab, params.T_bc, params.T_ca

    # Going around the a -> b -> c -> a loop should yield the identity.
    cycle_residual = (T_ab @ T_bc @ T_ca).log()
    # Each prior is the inverse of what its transform should converge to.
    ab_prior_residual = (T_ab @ prior_ba).log()
    bc_prior_residual = (T_bc @ prior_cb).log()

    return jnp.sum(
        cycle_residual**2 + ab_prior_residual**2 + bc_prior_residual**2
    )


# The three optimization strategies compared by this example (see module docstring).
Algorithm = Literal[
    "tangent_space",
    "projected",
    "exponential_coordinates",
]


@jdc.pytree_dataclass
class State:
    """Everything needed to take an optimization step: the parameters, the ADAM
    optimizer and its gradient statistics, and which algorithm to run.

    `optimizer` and `algorithm` are `jdc.Static` fields, so they are part of the pytree
    structure rather than traced arrays. This is what allows the jitted `step()` to
    branch on `self.algorithm` with a plain Python `if`.
    """

    # Current estimate: `ExponentialCoordinatesParameters` for the
    # "exponential_coordinates" algorithm, global `Parameters` otherwise.
    params: Union[Parameters, ExponentialCoordinatesParameters]
    optimizer: jdc.Static[optax.GradientTransformation]
    # ADAM statistics. They are initialized from zero tangents for "tangent_space"
    # and from `params` itself for the other two algorithms (see `initialize`).
    optimizer_state: optax.OptState
    algorithm: jdc.Static[Algorithm]

    @staticmethod
    def initialize(algorithm: Algorithm, learning_rate: float) -> State:
        """Initialize the state of our optimization problem.

        Note that the transform parameters won't initially be consistent;
        `T_ab @ T_bc != T_ca.inverse()`.

        Args:
            algorithm: Which optimization approach `step()` should use.
            learning_rate: Learning rate for ADAM.
        """
        # One independent key per transform. The key count must match the indices
        # used below. With too few keys, indexing past the end either raises or
        # (JAX clamps out-of-bounds gathers, depending on version) silently reuses a
        # key, which would make all three initial transforms identical.
        prngs = jax.random.split(jax.random.PRNGKey(0), num=3)
        global_params = Parameters(
            jaxlie.SE3.sample_uniform(prngs[0]),
            jaxlie.SE3.sample_uniform(prngs[1]),
            jaxlie.SE3.sample_uniform(prngs[2]),
        )

        # Make optimizer.
        params: Union[Parameters, ExponentialCoordinatesParameters]
        optimizer = optax.adam(learning_rate=learning_rate)
        if algorithm == "tangent_space":
            # Initialize gradient statistics as on the tangent space.
            params = global_params
            optimizer_state = optimizer.init(jaxlie.manifold.zero_tangents(params))
        elif algorithm == "projected":
            # Initialize gradient statistics directly in quaternion space.
            params = global_params
            optimizer_state = optimizer.init(params)
        elif algorithm == "exponential_coordinates":
            # Switch to a log-space parameterization.
            params = ExponentialCoordinatesParameters.from_global(global_params)
            optimizer_state = optimizer.init(params)
        else:
            assert_never(algorithm)

        return State(
            params=params,
            optimizer=optimizer,
            optimizer_state=optimizer_state,
            algorithm=algorithm,
        )

    @jax.jit
    def step(self: State) -> Tuple[jax.Array, State]:
        """Take one ADAM optimization step.

        Returns:
            A `(loss, new_state)` tuple. `loss` is evaluated at the parameters *before*
            this step's update.
        """

        if self.algorithm == "tangent_space":
            # ADAM step on manifold.
            #
            # `jaxlie.manifold.value_and_grad()` is a drop-in replacement for
            # `jax.value_and_grad()`, but for Lie group instances computes gradients on
            # the tangent space.
            loss, grads = jaxlie.manifold.value_and_grad(compute_loss)(self.params)
            # NOTE(review): `self.params` holds SE3 objects, not tangent vectors. Plain
            # ADAM does not read `params`, but a param-dependent transform (e.g. weight
            # decay) would need tangent-shaped values here.
            updates, new_optimizer_state = self.optimizer.update(
                grads,
                self.optimizer_state,
                self.params,
            )
            # Retract the tangent-space update back onto the group.
            new_params = jaxlie.manifold.rplus(self.params, updates)

        elif self.algorithm == "projected":
            # Projection-based approach.
            loss, grads = jax.value_and_grad(compute_loss)(self.params)
            updates, new_optimizer_state = self.optimizer.update(
                grads,
                self.optimizer_state,
                self.params,
            )
            new_params = optax.apply_updates(self.params, updates)

            # Project back to manifold.
            new_params = jaxlie.manifold.normalize_all(new_params)

        elif self.algorithm == "exponential_coordinates":
            # With exponential coordinates every parameter is an unconstrained vector,
            # so vanilla ADAM applies directly. No tangent-space gradients or projection
            # step are needed.
            loss, grads = jax.value_and_grad(compute_loss)(self.params)
            updates, new_optimizer_state = self.optimizer.update(
                grads,
                self.optimizer_state,
                self.params,
            )
            new_params = optax.apply_updates(self.params, updates)

        else:
            # Call directly: wrapping this in `assert` would strip the whole statement
            # under `python -O`.
            assert_never(self.algorithm)

        # Return updated structure.
        with jdc.copy_and_mutate(self, validate=True) as new_state:
            new_state.params = new_params
            new_state.optimizer_state = new_optimizer_state

        return loss, new_state


def run_experiment(
    algorithm: Algorithm, learning_rate: float, train_steps: int
) -> List[float]:
    """Run the SE(3) consistency problem with one of the optimization algorithms.

    Args:
        algorithm: Update rule to use: "tangent_space", "projected", or
            "exponential_coordinates".
        learning_rate: ADAM learning rate.
        train_steps: Number of timed optimization steps to take.

    Returns:
        The loss reported by each of the `train_steps` timed steps.
    """

    print(algorithm)
    state = State.initialize(algorithm, learning_rate)
    # Trigger JIT compilation up front so it is excluded from the timing below. The
    # returned state is discarded; the timed loop starts from the initial state.
    state.step()

    # perf_counter is monotonic and intended for measuring intervals; time.time is not.
    start_time = time.perf_counter()
    losses: List[float] = []
    for i in range(train_steps):
        loss, state = state.step()
        if i % 20 == 0:
            print(f"\t(step {i:03d}) Loss", loss, flush=True)
        # float() needs the concrete value, so it waits for the step to finish.
        losses.append(float(loss))
    print()
    print(f"\tConverged in {time.perf_counter() - start_time} seconds")
    print()
    print("\tAfter optimization, the following transforms should be consistent:")
    print(f"\t\t{state.params.T_ab @ state.params.T_bc=}")
    print(f"\t\t{state.params.T_ca.inverse()=}")

    return losses


def main(train_steps: int = 1000, learning_rate: float = 1e-1) -> None:
    """Run the pose optimization experiments and plot their loss curves.

    Args:
        train_steps: Number of training steps to take.
        learning_rate: Learning rate for our ADAM optimizers.
    """
    step_indices = range(train_steps)
    algorithms: Tuple[Algorithm, ...] = (
        "tangent_space",
        "projected",
        "exponential_coordinates",
    )

    # One curve per algorithm, all on the same axes.
    for name in algorithms:
        loss_curve = run_experiment(name, learning_rate, train_steps)
        plt.plot(step_indices, loss_curve, label=name)
        print()

    plt.yscale("log", base=2)
    plt.legend()
    plt.show()


if __name__ == "__main__":
    # tyro exposes `main`'s keyword arguments as command-line options (see `--help`).
    tyro.cli(main)