Non-Euclidean variables

Many variables in optimization live on non-Euclidean manifolds. Common examples include:

  • Rotations: SO(2), SO(3), quaternions

  • Rigid transformations: SE(2), SE(3)

  • Unit vectors: points on spheres

  • Probability simplices: points with nonnegative entries that sum to 1

  • Low-rank matrices, positive-definite matrices, etc.

Naive Euclidean optimization of these variables often fails because it ignores geometric constraints like unit norms or orthogonality.
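To see the failure concretely, here is a minimal sketch in plain jax.numpy (no jaxls involved). A Euclidean step applied to a unit vector leaves the sphere, and renormalizing is one simple way to map it back. The step vector is arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp

# A point on the unit sphere S^2.
x = jnp.array([0.0, 0.0, 1.0])

# An unconstrained Euclidean update (e.g., a gradient step); values are arbitrary.
step = jnp.array([0.3, -0.2, 0.1])
x_euclidean = x + step

# Map back onto the sphere by renormalizing (a simple projection retraction).
x_retracted = x_euclidean / jnp.linalg.norm(x_euclidean)

print(float(jnp.linalg.norm(x_euclidean)))  # ~1.158: the iterate has left the sphere
print(float(jnp.linalg.norm(x_retracted)))  # ~1.0: back on the sphere
```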

jaxls supports manifold optimization via custom retract_fn and tangent_dim parameters when defining variables.

Features used:

  • SO3Var for SO(3) rotation variables

  • Var with custom retract_fn and tangent_dim for manifold variables

import sys
from loguru import logger

# Print loguru log messages (such as the solver's progress) to stdout.
logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxlie
import jaxls

Example: rotation averaging

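Rotations in 3D form the SO(3) manifold. Naive Euclidean approaches run into two problems with quaternions. First, a rotation quaternion must have unit norm, but the component-wise average of unit quaternions generally does not, so the result must be renormalized after the fact, which discards the geometric structure of the problem. Second, q and -q represent the same rotation (the antipodal ambiguity), so averaging quaternions with mixed signs can make them nearly cancel out.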
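A small sketch of the antipodal ambiguity in plain jax.numpy: the helper quat_rotate is hypothetical (written here, not from jaxlie) and assumes scalar-first (w, x, y, z) unit quaternions, matching the wxyz layout used in this notebook. Rotating a vector by q or by -q gives the same result, yet their component-wise average is the zero vector.

```python
import jax.numpy as jnp


def quat_rotate(wxyz: jnp.ndarray, v: jnp.ndarray) -> jnp.ndarray:
    """Rotate vector v by the unit quaternion wxyz (scalar-first convention)."""
    w, xyz = wxyz[0], wxyz[1:]
    # Standard expansion of q * v * q^-1 for a unit quaternion.
    t = 2.0 * jnp.cross(xyz, v)
    return v + w * t + jnp.cross(xyz, t)


# An arbitrary unit quaternion and test vector (illustrative values).
q = jnp.array([0.9, 0.2, 0.3, 0.25])
q = q / jnp.linalg.norm(q)
v = jnp.array([1.0, 2.0, 3.0])

# q and -q rotate v identically: they are the same rotation.
print(quat_rotate(q, v))
print(quat_rotate(-q, v))

# But their component-wise average is the zero vector, which is not a rotation.
print((q + (-q)) / 2.0)
```

The sign flip cancels in the formula because both w and xyz change sign, so every term is either unchanged or multiplied by (-1)^2.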

Let’s generate some noisy rotation measurements around a ground truth:

# Ground truth rotation: about 0.71 rad (~40.5 degrees) around the axis (0.3, 0.4, 0.5) / ||(0.3, 0.4, 0.5)||.
ground_truth = jaxlie.SO3.exp(jnp.array([0.3, 0.4, 0.5]))

# Generate noisy measurements by perturbing with small random rotations.
num_measurements = 10
noise_std = 0.1  # Radians.

key = jax.random.PRNGKey(42)
noise_tangents = jax.random.normal(key, (num_measurements, 3)) * noise_std

# Apply noise: measurement = ground_truth @ exp(noise).
measurements = jax.vmap(lambda delta: ground_truth @ jaxlie.SO3.exp(delta))(
    noise_tangents
)

# Simulate antipodal ambiguity by negating half of the quaternions.
# q and -q represent the same rotation, but naive averaging doesn't know this:
# with half of them negated, the components nearly cancel out.
flip_mask = jnp.arange(num_measurements) < num_measurements // 2
flipped_wxyz = jnp.where(flip_mask[:, None], -measurements.wxyz, measurements.wxyz)
measurements_flipped = jaxlie.SO3(wxyz=flipped_wxyz)

print(f"Ground truth quaternion: {ground_truth.wxyz}")
print(f"Generated {num_measurements} noisy measurements")
print(f"Negated {int(flip_mask.sum())} quaternions to simulate antipodal ambiguity")
Ground truth quaternion: [0.9381483  0.14689447 0.1958593  0.24482411]
Generated 10 noisy measurements
Negated 5 quaternions to simulate antipodal ambiguity

Naive Euclidean averaging (fails)

A common mistake is to average quaternion components directly:

# Average the quaternion components (wrong!).
# This fails badly when quaternions have mixed signs due to antipodal ambiguity.
avg_quaternion = jnp.mean(measurements_flipped.wxyz, axis=0)

print(f"Averaged quaternion: {avg_quaternion}")
print(f"Quaternion norm: {jnp.linalg.norm(avg_quaternion):.4f} (should be 1.0)")

# Even if we renormalize, this approach is geometrically incorrect.
renormalized = avg_quaternion / (jnp.linalg.norm(avg_quaternion) + 1e-8)
naive_result = jaxlie.SO3(wxyz=renormalized)

# Compute geodesic error (rotation angle between result and ground truth).
naive_error = jnp.linalg.norm((naive_result.inverse() @ ground_truth).log())
print(
    f"\nNaive approach geodesic error: {float(naive_error):.4f} rad ({float(jnp.rad2deg(naive_error)):.2f} deg)"
)
Averaged quaternion: [ 0.00473162 -0.01714533 -0.00761085 -0.00510501]
Quaternion norm: 0.0200 (should be 1.0)

Naive approach geodesic error: 3.0596 rad (175.30 deg)

Manifold optimization

The standard approach is to optimize on the manifold using:

  1. Tangent space: At each point on the manifold, there’s a local linear approximation (the tangent space). For SO(3), tangent spaces are 3-dimensional; we can parameterize local updates using axis-angle vectors.

  2. Retraction: A function retract(x, delta) that maps a point on the manifold plus a tangent space update back to a valid point on the manifold. For SO(3), we use the exponential map: \(R_{\text{new}} = R_{\text{current}} \cdot \exp(\delta)\), where \(\delta \in \mathbb{R}^3\) is an axis-angle perturbation.
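A minimal sketch of this retraction using rotation matrices in plain jax.numpy, rather than jaxlie's quaternion-based SO3. The helper names hat, so3_exp, and so3_retract are hypothetical. The exponential map is Rodrigues' formula, with Taylor expansions guarding small angles.

```python
import jax.numpy as jnp


def hat(omega: jnp.ndarray) -> jnp.ndarray:
    """Map an axis-angle vector in R^3 to its 3x3 skew-symmetric matrix."""
    return jnp.array(
        [
            [0.0, -omega[2], omega[1]],
            [omega[2], 0.0, -omega[0]],
            [-omega[1], omega[0], 0.0],
        ]
    )


def so3_exp(omega: jnp.ndarray) -> jnp.ndarray:
    """Rodrigues' formula: exp(hat(omega)) as a 3x3 rotation matrix."""
    theta = jnp.linalg.norm(omega)
    small = theta < 1e-6
    safe = jnp.where(small, 1.0, theta)  # avoid dividing by zero
    # sin(t)/t and (1 - cos(t))/t^2, with Taylor expansions near t = 0.
    a = jnp.where(small, 1.0 - theta**2 / 6.0, jnp.sin(safe) / safe)
    b = jnp.where(small, 0.5 - theta**2 / 24.0, (1.0 - jnp.cos(safe)) / safe**2)
    K = hat(omega)
    return jnp.eye(3) + a * K + b * (K @ K)


def so3_retract(R: jnp.ndarray, delta: jnp.ndarray) -> jnp.ndarray:
    """Retraction on SO(3): R_new = R @ exp(delta)."""
    return R @ so3_exp(delta)


R = so3_exp(jnp.array([0.3, 0.4, 0.5]))
R_new = so3_retract(R, jnp.array([0.05, -0.02, 0.1]))

print(jnp.allclose(so3_retract(R, jnp.zeros(3)), R))  # True: identity at zero
print(jnp.allclose(R_new.T @ R_new, jnp.eye(3), atol=1e-5))  # True: still orthogonal
print(float(jnp.linalg.det(R_new)))  # ~1.0: still a proper rotation
```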

A valid retraction must satisfy:

  • Identity at zero: retract(x, 0) = x

  • Stays on manifold: output always satisfies the manifold constraints

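  • Smoothness: the function should be differentiable

  • Local rigidity: near delta = 0, different small updates move x in different tangent directions, i.e. the Jacobian of delta -> retract(x, delta) at delta = 0 has full rank tangent_dim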
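These properties can be spot-checked numerically. Below is a hedged sketch: check_retraction, circle_retract, and unit_norm_violation are hypothetical helpers written for this illustration (not part of jaxls), and the violation function is something you would supply for each manifold. As an extra sanity check, it also reports the rank of the Jacobian of delta -> retract(x, delta) at zero, which should equal tangent_dim so that every tangent coordinate actually moves the point.

```python
import jax
import jax.numpy as jnp


def check_retraction(retract_fn, violation_fn, x, tangent_dim, key, atol=1e-5):
    """Spot-check a retraction at the point x.

    Returns (identity_at_zero, stays_on_manifold, jacobian_rank).
    """
    zero = jnp.zeros(tangent_dim)
    # Identity at zero: retract(x, 0) == x.
    identity_ok = jnp.allclose(retract_fn(x, zero), x, atol=atol)
    # Stays on manifold: random (not necessarily small) steps remain feasible.
    deltas = jax.random.normal(key, (16, tangent_dim))
    outputs = jax.vmap(lambda d: retract_fn(x, d))(deltas)
    on_manifold = jnp.all(jax.vmap(violation_fn)(outputs) < atol)
    # Differentiability: JAX can differentiate delta -> retract(x, delta) at zero.
    jac = jax.jacfwd(lambda d: retract_fn(x, d))(zero)
    rank = jnp.linalg.matrix_rank(jac.reshape(-1, tangent_dim))
    return bool(identity_ok), bool(on_manifold), int(rank)


def circle_retract(p, delta):
    """Hypothetical retraction on the unit circle S^1 (tangent_dim = 1)."""
    tangent = jnp.array([-p[1], p[0]])  # p rotated by 90 degrees
    q = p + delta[0] * tangent
    return q / jnp.linalg.norm(q)


def unit_norm_violation(p):
    """Distance from satisfying the unit-norm constraint."""
    return jnp.abs(jnp.linalg.norm(p) - 1.0)


p = jnp.array([jnp.cos(0.7), jnp.sin(0.7)])
print(check_retraction(circle_retract, unit_norm_violation, p, 1, jax.random.PRNGKey(0)))
# (True, True, 1)
```

Dropping the final normalization in circle_retract would make the second check fail, since any nonzero tangent step then leaves the circle.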

The solver optimizes in the tangent space (which is Euclidean), then uses retraction to update the manifold variable.
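To make "step in the tangent space, then retract" concrete, here is a minimal sketch of Riemannian gradient descent in plain JAX. It is only an illustration of the pattern: jaxls computes its steps differently (the solver log in this notebook shows a damping parameter, lambd). The sketch copies the projection retraction this notebook defines for UnitSphereVar; the cost, target, step size 0.5, and iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def sphere_retract(point, delta):
    """Projection retraction on S^2 with 2D local coordinates."""
    # Orthonormal tangent basis (e1, e2) at point, built from an auxiliary axis.
    aux = jnp.where(
        jnp.abs(point[0]) < 0.9, jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0])
    )
    e1 = aux - jnp.dot(aux, point) * point
    e1 = e1 / jnp.linalg.norm(e1)
    e2 = jnp.cross(point, e1)
    new_point = point + delta[0] * e1 + delta[1] * e2
    return new_point / jnp.linalg.norm(new_point)


# Illustrative problem: find the point on the sphere closest to a fixed target.
target = jnp.array([1.0, 2.0, 2.0]) / 3.0  # already unit norm


def cost(x):
    return jnp.sum((x - target) ** 2)


x = jnp.array([0.0, 0.0, 1.0])
for _ in range(10):
    # Euclidean gradient with respect to the tangent coordinates at delta = 0.
    grad = jax.grad(lambda d: cost(sphere_retract(x, d)))(jnp.zeros(2))
    # Take a step in the (Euclidean) tangent space, then retract onto the sphere.
    x = sphere_retract(x, -0.5 * grad)

print(x)  # approaches target = [1/3, 2/3, 2/3]
```

On the sphere this cost equals 2 - 2 x·target, whose curvature at the minimum is 2, so a step size of 0.5 acts like a Newton step and the loop converges in a few iterations.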

Using SO3Var

jaxls provides built-in Lie group variables through jaxlie. The SO3Var class handles manifold optimization automatically:

# SO3Var is defined as:
# class SO3Var(
#     Var[jaxlie.SO3],
#     default_factory=jaxlie.SO3.identity,
#     retract_fn=jaxlie.manifold.rplus,  # R_new = R_old @ exp(delta).
#     tangent_dim=3,  # SO(3) has 3 degrees of freedom.
# ): ...

rotation_var = jaxls.SO3Var(id=0)

print(f"SO3Var tangent dimension: {jaxls.SO3Var.tangent_dim}")
SO3Var tangent dimension: 3

Rotation averaging cost

For rotation averaging, we minimize the sum of squared geodesic distances to each measurement. The geodesic distance on SO(3) is the angle of rotation between two orientations:

\[d(R, R_i) = \|\log(R^{-1} R_i)\|\]

The log map returns a 3D axis-angle vector whose norm is the rotation angle.
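For intuition about what (R.inverse() @ measurement).log() computes, here is a matrix-based sketch in plain jax.numpy. The helpers so3_log, geodesic_distance, and rot_z are hypothetical; jaxlie works with quaternions and handles edge cases (such as angles near pi) more carefully than this version does.

```python
import jax.numpy as jnp


def so3_log(R: jnp.ndarray) -> jnp.ndarray:
    """Axis-angle vector of a rotation matrix (not robust near angle = pi)."""
    cos_theta = jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = jnp.arccos(cos_theta)
    # The skew-symmetric part (R - R^T) / 2 equals sin(theta) * hat(axis).
    v = 0.5 * jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    small = theta < 1e-6
    scale = jnp.where(small, 1.0, theta / jnp.sin(jnp.where(small, 1.0, theta)))
    return scale * v


def geodesic_distance(R_a: jnp.ndarray, R_b: jnp.ndarray) -> jnp.ndarray:
    """d(R_a, R_b) = ||log(R_a^{-1} R_b)||, using R^{-1} = R^T for rotations."""
    return jnp.linalg.norm(so3_log(R_a.T @ R_b))


def rot_z(angle):
    """Rotation by `angle` radians about the z axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


print(so3_log(rot_z(0.5)))  # ~[0, 0, 0.5]: axis z, angle 0.5 rad
print(geodesic_distance(rot_z(0.3), rot_z(1.0)))  # ~0.7 rad
```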

@jaxls.Cost.factory
def rotation_cost(
    vals: jaxls.VarValues,
    var: jaxls.SO3Var,
    measurement: jaxlie.SO3,
) -> jax.Array:
    """Penalize deviation from measurement using geodesic distance.

    Args:
        vals: Current variable values.
        var: The rotation variable to optimize.
        measurement: Target rotation measurement.

    Returns:
        3D residual vector (axis-angle difference).
    """
    R = vals[var]
    # Geodesic error: log(R^{-1} @ measurement).
    return (R.inverse() @ measurement).log()

Solving with manifold optimization

Create costs for each measurement and solve:

# Create batched costs for all measurements.
# Note: we use the original measurements (with consistent quaternion signs).
# The manifold approach works correctly regardless of quaternion sign.
costs = [
    rotation_cost(
        jaxls.SO3Var(id=jnp.zeros(num_measurements, dtype=jnp.int32)),
        measurements,  # Original measurements work fine.
    )
]

# Start from identity rotation.
initial_vals = jaxls.VarValues.make([rotation_var])

# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [rotation_var])

# Visualize the problem structure.
problem.show()
# Analyze and solve.
problem = problem.analyze()
solution = problem.solve(initial_vals)
INFO     | Building optimization problem with 10 terms and 1 variables: 10 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 10 costs, 1 variables each: rotation_cost
INFO     |  step #0: cost=5.6207 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - rotation_cost(10): 5.62067 (avg 0.18736)
INFO     |      accepted=True ATb_norm=9.08e+00 cost_prev=5.6207 cost_new=0.2267
INFO     |  step #1: cost=0.2267 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - rotation_cost(10): 0.22673 (avg 0.00756)
INFO     |      accepted=True ATb_norm=2.78e-02 cost_prev=0.2267 cost_new=0.2267
INFO     |  step #2: cost=0.2267 lambd=0.0001 inexact_tol=8.4e-06
INFO     |      - rotation_cost(10): 0.22668 (avg 0.00756)
INFO     |      accepted=True ATb_norm=8.53e-05 cost_prev=0.2267 cost_new=0.2267
INFO     | Terminated @ iteration #3: cost=0.2267 criteria=[1 0 0], term_deltas=2.6e-07,3.5e-05,6.9e-06
# Extract result and compute error.
manifold_result = solution[rotation_var]
manifold_error = jnp.linalg.norm((manifold_result.inverse() @ ground_truth).log())

print(f"Ground truth quaternion:  {ground_truth.wxyz}")
print(f"Manifold result quaternion: {manifold_result.wxyz}")
print(
    f"\nManifold approach geodesic error: {float(manifold_error):.4f} rad ({float(jnp.rad2deg(manifold_error)):.2f} deg)"
)
print(
    f"Naive approach geodesic error:    {float(naive_error):.4f} rad ({float(jnp.rad2deg(naive_error)):.2f} deg)"
)
Ground truth quaternion:  [0.9381483  0.14689447 0.1958593  0.24482411]
Manifold result quaternion: [0.93323416 0.13931717 0.20464188 0.26035807]

Manifold approach geodesic error: 0.0400 rad (2.29 deg)
Naive approach geodesic error:    3.0596 rad (175.30 deg)
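For intuition about what the solver converges to: with residuals r_i = log(R^{-1} R_i) and updates R exp(delta), each residual changes by roughly -delta for small steps (when the residuals themselves are small). A Gauss-Newton step with that approximate Jacobian is simply the mean residual, and iterating R_new = R exp(mean_i r_i) is the classic fixed-point iteration for the Karcher (geodesic) mean. The sketch below is plain jax.numpy with rotation matrices, not jaxls, and its helpers are hypothetical. It reuses this notebook's ground truth, noise level, and PRNG key, but we have not verified that it reproduces the solver's numbers digit for digit.

```python
import jax
import jax.numpy as jnp


def hat(w):
    return jnp.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])


def so3_exp(w):
    """Rodrigues' formula (small-angle case handled with a first-order guard)."""
    theta = jnp.linalg.norm(w)
    small = theta < 1e-6
    t = jnp.where(small, 1.0, theta)
    a = jnp.where(small, 1.0, jnp.sin(t) / t)
    b = jnp.where(small, 0.5, (1.0 - jnp.cos(t)) / t**2)
    K = hat(w)
    return jnp.eye(3) + a * K + b * (K @ K)


def so3_log(R):
    """Axis-angle vector of a rotation matrix (not robust near angle = pi)."""
    cos_theta = jnp.clip((jnp.trace(R) - 1.0) / 2.0, -1.0, 1.0)
    theta = jnp.arccos(cos_theta)
    v = 0.5 * jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    small = theta < 1e-6
    return jnp.where(small, 1.0, theta / jnp.sin(jnp.where(small, 1.0, theta))) * v


def rotation_average(Rs, num_iters=20):
    """Karcher-mean iteration: R <- R @ exp(mean_i log(R^T R_i))."""
    R = jnp.eye(3)  # Start from identity, as in the jaxls example.
    for _ in range(num_iters):
        residuals = jax.vmap(so3_log)(R.T @ Rs)  # (N, 3) tangent residuals
        R = R @ so3_exp(jnp.mean(residuals, axis=0))
    return R


# Same setup as the notebook: ground truth, 10 measurements, 0.1 rad noise.
ground_truth = so3_exp(jnp.array([0.3, 0.4, 0.5]))
noise = jax.random.normal(jax.random.PRNGKey(42), (10, 3)) * 0.1
Rs = ground_truth @ jax.vmap(so3_exp)(noise)

R_avg = rotation_average(Rs)
error = jnp.linalg.norm(so3_log(R_avg.T @ ground_truth))
print(f"Karcher-mean geodesic error: {float(error):.4f} rad")
```

Because every residual is computed relative to the current estimate, the sign of any stored quaternion never enters this computation, which is one way to see why the manifold formulation is immune to the antipodal ambiguity.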

Defining custom manifold variables

To create your own manifold variable, subclass Var and pass the following as class keyword arguments:

  • default_factory: A callable returning the default value

  • retract_fn: A function (current_value, tangent_delta) -> new_value that applies a tangent space update, where tangent_delta is a 1D array with shape (tangent_dim,)

  • tangent_dim: The dimension of the local tangent space

Here’s how SO3Var is implemented (you can define similar variables for other manifolds):

# Custom SO(3) variable (equivalent to jaxls.SO3Var).
class CustomSO3Var(
    jaxls.Var[jaxlie.SO3],
    default_factory=jaxlie.SO3.identity,
    retract_fn=jaxlie.manifold.rplus,  # rplus(R, delta) = R @ SO3.exp(delta).
    tangent_dim=3,
):
    """Custom SO(3) rotation variable."""


# Example: Unit sphere manifold (S^2).
def sphere_retract(point: jax.Array, delta: jax.Array) -> jax.Array:
    """Retract from tangent plane back to sphere.

    Args:
        point: Current point on unit sphere (3,).
        delta: Tangent vector in local coordinates (2,).

    Returns:
        New point on unit sphere (3,).
    """
    # Build orthonormal basis for tangent plane.
    # Choose a vector not parallel to point.
    aux = jnp.where(
        jnp.abs(point[0]) < 0.9, jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0])
    )
    e1 = aux - jnp.dot(aux, point) * point
    e1 = e1 / jnp.linalg.norm(e1)
    e2 = jnp.cross(point, e1)

    # Move in tangent plane and project back to sphere.
    new_point = point + delta[0] * e1 + delta[1] * e2
    return new_point / jnp.linalg.norm(new_point)


class UnitSphereVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array([0.0, 0.0, 1.0]),
    retract_fn=sphere_retract,
    tangent_dim=2,  # Sphere is 2D manifold embedded in 3D.
):
    """Point on the unit sphere S^2."""


print(f"CustomSO3Var tangent_dim: {CustomSO3Var.tangent_dim}")
print(f"UnitSphereVar tangent_dim: {UnitSphereVar.tangent_dim}")
CustomSO3Var tangent_dim: 3
UnitSphereVar tangent_dim: 2

Visualization

import contextlib
import io
import numpy as np
import viser

# Create Viser server (suppress output).
with (
    contextlib.redirect_stdout(io.StringIO()),
    contextlib.redirect_stderr(io.StringIO()),
):
    server = viser.ViserServer(verbose=False)
server.scene.set_up_direction("+z")

# Set initial camera position for a good view of all three frames.
server.initial_camera.position = (0.6, -1.5, 0.6)
server.initial_camera.look_at = (0.6, 0.0, 0.0)


# Helper to add a coordinate frame.
def add_frame(
    name: str,
    R: jaxlie.SO3,
    position: tuple[float, float, float],
    scale: float = 0.25,
    opacity: float = 1.0,
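) -> None:
    """Add a coordinate frame to the scene.

    Note: opacity is not passed to viser; it only selects a thinner axis
    radius for faint frames (opacity <= 0.5).
    """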
    server.scene.add_frame(
        name,
        wxyz=np.array(R.wxyz),
        position=position,
        axes_length=scale,
        axes_radius=0.008 if opacity > 0.5 else 0.004,
    )


# Add measurement frames (smaller, representing noisy observations).
for i in range(num_measurements):
    R_meas = jaxlie.SO3(wxyz=measurements.wxyz[i])
    add_frame(
        f"/measurements/frame_{i}", R_meas, (0.0, 0.0, 0.0), scale=0.15, opacity=0.3
    )

# Add ground truth frame.
add_frame("/ground_truth", ground_truth, (0.0, 0.0, 0.0), scale=0.25)

# Add manifold result frame (offset for visibility).
add_frame("/manifold_result", manifold_result, (0.6, 0.0, 0.0), scale=0.25)

# Add naive result frame.
add_frame("/naive_result", naive_result, (1.2, 0.0, 0.0), scale=0.25)

# Add labels.
server.scene.add_label(
    "/labels/ground_truth",
    text="Ground Truth\n+ Measurements",
    position=(0.0, 0.0, -0.2),
)
server.scene.add_label(
    "/labels/manifold",
    text=f"Manifold\n({float(jnp.rad2deg(manifold_error)):.1f}° error)",
    position=(0.6, 0.0, -0.2),
)
server.scene.add_label(
    "/labels/naive",
    text=f"Naive\n({float(jnp.rad2deg(naive_error)):.1f}° error)",
    position=(1.2, 0.0, -0.2),
)

# Display inline.
server.scene.show(height=400)

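In this example, the manifold-aware approach recovers the rotation to within about 2.3 degrees, while the naive quaternion average is off by about 175 degrees. The naive average fails because it ignores the geometry of the manifold: it treats unit quaternions as ordinary points in R^4, so it neither preserves unit norm nor accounts for q and -q representing the same rotation.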

Key points:

  • Define retract_fn to map tangent space updates back to the manifold

  • Set tangent_dim to the local degrees of freedom (not the ambient dimension)

  • Use geodesic costs that respect manifold geometry (e.g., the log map for Lie groups)

Built-in Lie group variables:

For custom manifolds (spheres, simplices, etc.), define your own variable class with appropriate retract_fn and tangent_dim as shown above.
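As one illustration for the simplex case, here is a hypothetical retraction for the interior of the probability simplex (strictly positive entries summing to 1). It is a sketch under stated assumptions, not a jaxls built-in: it adds the tangent update in log space along zero-sum directions and maps back with a softmax, so tangent_dim is n - 1 rather than n. Following the UnitSphereVar pattern above, it could presumably be passed as retract_fn with tangent_dim=n - 1, though we have not run it through the jaxls solver.

```python
import jax
import jax.numpy as jnp


def simplex_retract(p: jax.Array, delta: jax.Array) -> jax.Array:
    """Hypothetical retraction on the open probability simplex.

    Args:
        p: Current point with strictly positive entries summing to 1, shape (n,).
        delta: Tangent update in local coordinates, shape (n - 1,).

    Returns:
        New point with strictly positive entries summing to 1, shape (n,).
    """
    n = p.shape[0]
    # Columns e_k - e_n (k < n) span the zero-sum directions in log space.
    # Softmax ignores shifts along all-ones, so only n - 1 directions matter.
    basis = jnp.concatenate([jnp.eye(n - 1), -jnp.ones((1, n - 1))], axis=0)
    return jax.nn.softmax(jnp.log(p) + basis @ delta)


p = jnp.array([0.2, 0.3, 0.5])
q = simplex_retract(p, jnp.array([0.4, -1.0]))
print(q, float(q.sum()))  # positive entries, sum ~1.0
print(jnp.allclose(simplex_retract(p, jnp.zeros(2), ), p, atol=1e-6))  # True: identity at zero
```

Using n - 1 coordinates instead of n matters for a least-squares solver: adding the same constant to every logit leaves the softmax unchanged, so an n-dimensional parameterization would have a rank-deficient Jacobian.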