Non-Euclidean variables
Many variables in optimization live on non-Euclidean manifolds. Common examples include:
Rotations: SO(2), SO(3), quaternions
Rigid transformations: SE(2), SE(3)
Unit vectors: points on spheres
Probability simplices: points that sum to 1
Low-rank matrices, positive-definite matrices, etc.
Naive Euclidean optimization of these variables often fails because it ignores geometric constraints like unit norms or orthogonality.
jaxls supports manifold optimization via custom retract_fn and tangent_dim parameters when defining variables.
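Here is a minimal plain-JAX sketch of the problem (not jaxls code). A single Euclidean step on a point of the unit sphere produces a vector that no longer has unit norm. Renormalizing the result, which is one simple retraction for spheres, puts it back on the manifold. The step values are arbitrary.

```python
import jax.numpy as jnp

# A point on the unit sphere S^2 and an arbitrary Euclidean update.
x = jnp.array([0.0, 0.0, 1.0])
step = jnp.array([0.3, -0.2, 0.1])

# Naive Euclidean update: the result leaves the sphere.
x_euclidean = x + step
print(f"Norm after Euclidean update: {float(jnp.linalg.norm(x_euclidean)):.4f}")

# Retract back onto the sphere by renormalizing.
x_retracted = x_euclidean / jnp.linalg.norm(x_euclidean)
print(f"Norm after retraction: {float(jnp.linalg.norm(x_retracted)):.4f}")
```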
Features used:
SO3Var for SO(3) rotation variables
Var with custom retract_fn and tangent_dim for manifold variables
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
Example: rotation averaging
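Rotations in 3D form the SO(3) manifold. Unit quaternions represent these rotations, but they form a double cover: q and -q encode the same rotation. Naive Euclidean averaging of quaternion components breaks down in two ways. First, the average generally does not have unit norm, so it must be renormalized. Second, quaternions with opposite signs can cancel each other even when they describe nearly the same rotation. Renormalizing does not fix either problem, because the averaging step has already discarded the geometric structure of the problem.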
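The double cover is easy to check numerically. This sketch uses a hand-written, hypothetical helper `quat_to_matrix` in the (w, x, y, z) order that jaxlie's `wxyz` attribute uses. It also uses the ground-truth quaternion from the example that follows. The check shows that q and -q produce the same rotation matrix, but their component-wise average is the zero vector.

```python
import jax.numpy as jnp


def quat_to_matrix(q):
    """Rotation matrix of a quaternion in (w, x, y, z) order (normalized first)."""
    w, x, y, z = q / jnp.linalg.norm(q)
    return jnp.array(
        [
            [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
            [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
            [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
        ]
    )


q = jnp.array([0.9381483, 0.14689447, 0.1958593, 0.24482411])

# Every entry is quadratic in (w, x, y, z), so flipping all signs changes nothing.
R_pos = quat_to_matrix(q)
R_neg = quat_to_matrix(-q)
print("Same rotation:", bool(jnp.allclose(R_pos, R_neg)))

# Averaging the two quaternion representations cancels them completely.
midpoint = 0.5 * (q + (-q))
print("Norm of the averaged quaternion:", float(jnp.linalg.norm(midpoint)))
```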
Let’s generate some noisy rotation measurements around a ground truth:
# Ground truth rotation: axis-angle (0.3, 0.4, 0.5), i.e. ~0.71 rad (~40.5 degrees).
ground_truth = jaxlie.SO3.exp(jnp.array([0.3, 0.4, 0.5]))
# Generate noisy measurements by perturbing with small random rotations.
num_measurements = 10
noise_std = 0.1 # Radians.
key = jax.random.PRNGKey(42)
noise_tangents = jax.random.normal(key, (num_measurements, 3)) * noise_std
# Apply noise: measurement = ground_truth @ exp(noise).
measurements = jax.vmap(lambda delta: ground_truth @ jaxlie.SO3.exp(delta))(
    noise_tangents
)
# Simulate antipodal ambiguity: negate half of the quaternions.
# (q and -q represent the same rotation, but naive averaging doesn't know this)
# When half are negated, they nearly cancel out!
flip_mask = jnp.arange(num_measurements) < num_measurements // 2
flipped_wxyz = jnp.where(flip_mask[:, None], -measurements.wxyz, measurements.wxyz)
measurements_flipped = jaxlie.SO3(wxyz=flipped_wxyz)
print(f"Ground truth quaternion: {ground_truth.wxyz}")
print(f"Generated {num_measurements} noisy measurements")
print(f"Negated {int(flip_mask.sum())} quaternions to simulate antipodal ambiguity")
Ground truth quaternion: [0.9381483 0.14689447 0.1958593 0.24482411]
Generated 10 noisy measurements
Negated 5 quaternions to simulate antipodal ambiguity
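The printed quaternion can be checked by hand. The axis-angle vector \(\omega = (0.3, 0.4, 0.5)\) has norm \(\theta\), which is the rotation angle. A unit quaternion stores \(\cos(\theta/2)\) together with \(\sin(\theta/2)\) times the unit axis:
\[
\theta = \lVert \omega \rVert = \sqrt{0.3^2 + 0.4^2 + 0.5^2} = \sqrt{0.5} \approx 0.7071\ \text{rad} \approx 40.5^\circ,
\]
\[
q = \left(\cos\tfrac{\theta}{2},\ \sin\tfrac{\theta}{2}\,\tfrac{\omega}{\theta}\right) \approx (0.9381,\ 0.1469,\ 0.1959,\ 0.2448),
\]
This matches the printed ground-truth quaternion.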
Naive Euclidean averaging (fails)
A common mistake is to average quaternion components directly:
# Average the quaternion components (wrong!).
# This fails badly when quaternions have mixed signs due to antipodal ambiguity.
avg_quaternion = jnp.mean(measurements_flipped.wxyz, axis=0)
print(f"Averaged quaternion: {avg_quaternion}")
print(f"Quaternion norm: {jnp.linalg.norm(avg_quaternion):.4f} (should be 1.0)")
# Even if we renormalize, this approach is geometrically incorrect.
renormalized = avg_quaternion / (jnp.linalg.norm(avg_quaternion) + 1e-8)
naive_result = jaxlie.SO3(wxyz=renormalized)
# Compute geodesic error (rotation angle between result and ground truth).
naive_error = jnp.linalg.norm((naive_result.inverse() @ ground_truth).log())
print(
    f"\nNaive approach geodesic error: {float(naive_error):.4f} rad ({float(jnp.rad2deg(naive_error)):.2f} deg)"
)
Averaged quaternion: [ 0.00473162 -0.01714533 -0.00761085 -0.00510501]
Quaternion norm: 0.0200 (should be 1.0)
Naive approach geodesic error: 3.0596 rad (175.30 deg)
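A common workaround, shown here only to illustrate the failure mode, is to flip every quaternion into the same hemisphere as a reference before averaging. This plain-JAX sketch uses four slightly perturbed copies of the ground-truth quaternion with alternating signs; the perturbations are arbitrary. The component-wise mean nearly vanishes, while the sign-aligned mean recovers the rotation. Even with aligned signs, this heuristic computes a chordal average in quaternion space rather than a geodesic one, so it is not a substitute for the manifold approach.

```python
import jax.numpy as jnp

# Ground-truth quaternion (w, x, y, z) from the output above.
q_ref = jnp.array([0.9381483, 0.14689447, 0.1958593, 0.24482411])

# Four slightly perturbed, renormalized copies with alternating signs.
qs = q_ref + 0.01 * jnp.eye(4)
qs = qs / jnp.linalg.norm(qs, axis=1, keepdims=True)
qs = jnp.array([1.0, -1.0, 1.0, -1.0])[:, None] * qs

# Naive component-wise mean: opposite-sign copies cancel.
naive_mean = jnp.mean(qs, axis=0)
print(f"Naive mean norm: {float(jnp.linalg.norm(naive_mean)):.4f}")

# Heuristic: flip each quaternion into the hemisphere of the first one, then average.
aligned = jnp.where((qs @ qs[0])[:, None] < 0, -qs, qs)
aligned_mean = jnp.mean(aligned, axis=0)
aligned_mean = aligned_mean / jnp.linalg.norm(aligned_mean)
print(f"|<aligned mean, ground truth>|: {float(jnp.abs(aligned_mean @ q_ref)):.4f}")
```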
Manifold optimization
The standard approach is to optimize on the manifold using:
Tangent space: At each point on the manifold, there’s a local linear approximation (the tangent space). For SO(3), tangent spaces are 3-dimensional; we can parameterize local updates using axis-angle vectors.
Retraction: A function retract(x, delta) that maps a point on the manifold plus a tangent space update back to a valid point on the manifold. For SO(3), we use the exponential map: \(R_{\text{new}} = R_{\text{current}} \cdot \exp(\delta)\), where \(\delta \in \mathbb{R}^3\) is an axis-angle perturbation.
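To make the two ingredients concrete, here is a plain-JAX sketch on rotation matrices. It is independent of jaxlie, which provides these operations for you. The sketch covers three pieces: the hat map from an axis-angle vector to a skew-symmetric matrix, the SO(3) exponential map via Rodrigues' formula, and the retraction \(R \cdot \exp(\delta)\). The helper names `hat`, `so3_exp`, and `so3_retract` are illustrative.

```python
import jax.numpy as jnp


def hat(omega):
    """Map an axis-angle vector in R^3 to its 3x3 skew-symmetric matrix."""
    return jnp.array(
        [
            [0.0, -omega[2], omega[1]],
            [omega[2], 0.0, -omega[0]],
            [-omega[1], omega[0], 0.0],
        ]
    )


def so3_exp(omega):
    """Rodrigues' formula: exp(hat(omega)) for an axis-angle vector omega."""
    theta_sq = jnp.sum(omega**2)
    small = theta_sq < 1e-8
    # Guard the angle so there is no 0/0 (or NaN gradient) at omega = 0.
    theta = jnp.sqrt(jnp.where(small, 1.0, theta_sq))
    half = 0.5 * theta
    # a = sin(theta)/theta, b = (1 - cos(theta))/theta^2, with Taylor fallbacks.
    a = jnp.where(small, 1.0 - theta_sq / 6.0, jnp.sin(theta) / theta)
    b = jnp.where(small, 0.5 - theta_sq / 24.0, 0.5 * (jnp.sin(half) / half) ** 2)
    K = hat(omega)
    return jnp.eye(3) + a * K + b * (K @ K)


def so3_retract(R, delta):
    """Right-multiplicative retraction: R_new = R @ exp(delta)."""
    return R @ so3_exp(delta)


R = so3_exp(jnp.array([0.3, 0.4, 0.5]))
R_new = so3_retract(R, jnp.array([0.01, -0.02, 0.03]))
# The retracted matrix is still a rotation: orthonormal with determinant 1.
print(bool(jnp.allclose(R_new.T @ R_new, jnp.eye(3), atol=1e-5)))
print(float(jnp.linalg.det(R_new)))
```

Writing \(b\) as \(\tfrac{1}{2}(\sin(\theta/2)/(\theta/2))^2\) avoids the cancellation in \(1 - \cos\theta\) for small angles in float32.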
A valid retraction must satisfy:
Identity at zero: retract(x, 0) = x
Stays on manifold: the output always satisfies the manifold constraints
Smoothness: the function should be differentiable
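The three conditions can be checked numerically for a simple case. This sketch stores a 2D rotation as a unit complex number \((\cos\phi, \sin\phi)\), which has a one-dimensional tangent space, so tangent_dim would be 1. The function `so2_retract` is a hypothetical illustration, not jaxls's SO2Var implementation. `jax.jacfwd` confirms the retraction is differentiable in delta. Its Jacobian at zero is a nonzero vector orthogonal to the current point, so it spans the tangent line.

```python
import jax
import jax.numpy as jnp


def so2_retract(z, delta):
    """Rotate z = (cos(phi), sin(phi)) by delta[0] radians (complex multiplication)."""
    c, s = jnp.cos(delta[0]), jnp.sin(delta[0])
    return jnp.array([z[0] * c - z[1] * s, z[0] * s + z[1] * c])


z = jnp.array([jnp.cos(0.7), jnp.sin(0.7)])
zero = jnp.zeros(1)  # tangent_dim = 1.

# 1. Identity at zero: retract(z, 0) == z.
identity_ok = jnp.allclose(so2_retract(z, zero), z)

# 2. Stays on the manifold: the result still has unit norm, even for a large step.
on_manifold = jnp.abs(jnp.linalg.norm(so2_retract(z, jnp.array([2.5]))) - 1.0) < 1e-6

# 3. Differentiable: the Jacobian w.r.t. delta at zero has shape (2, 1).
J = jax.jacfwd(so2_retract, argnums=1)(z, zero)

print(bool(identity_ok), bool(on_manifold), J[:, 0], float(jnp.dot(J[:, 0], z)))
```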
At each iteration, the solver computes an update in the tangent space (which is Euclidean) and then applies the retraction to move the manifold variable.
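Below is a minimal sketch of that loop, written in plain JAX on rotation matrices rather than with jaxls. Each iteration does three things:
1. Map every measurement into the tangent space at the current estimate with the log map.
2. Average those tangent vectors to get a Euclidean step.
3. Retract with the exponential map.

This is the classic iterative (Karcher) mean for rotations, not jaxls's Levenberg-Marquardt solver. For small residuals, each update is close to a Gauss-Newton step on a sum-of-squared-geodesic-distances cost. The helpers and the symmetric test data are illustrative assumptions.

```python
import jax.numpy as jnp


def hat(omega):
    """Axis-angle vector in R^3 -> 3x3 skew-symmetric matrix."""
    return jnp.array(
        [
            [0.0, -omega[2], omega[1]],
            [omega[2], 0.0, -omega[0]],
            [-omega[1], omega[0], 0.0],
        ]
    )


def so3_exp(omega):
    """Rodrigues' formula, with a Taylor fallback near omega = 0."""
    theta_sq = jnp.sum(omega**2)
    small = theta_sq < 1e-8
    theta = jnp.sqrt(jnp.where(small, 1.0, theta_sq))
    half = 0.5 * theta
    a = jnp.where(small, 1.0 - theta_sq / 6.0, jnp.sin(theta) / theta)
    b = jnp.where(small, 0.5 - theta_sq / 24.0, 0.5 * (jnp.sin(half) / half) ** 2)
    K = hat(omega)
    return jnp.eye(3) + a * K + b * (K @ K)


def so3_log(R):
    """Inverse of so3_exp for rotation angles well below pi."""
    # The skew-symmetric part of R encodes sin(theta) * axis.
    w = 0.5 * jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    cos_theta = 0.5 * (jnp.trace(R) - 1.0)
    sin_sq = jnp.sum(w**2)
    small = sin_sq < 1e-12
    sin_theta = jnp.sqrt(jnp.where(small, 1.0, sin_sq))
    theta = jnp.arctan2(sin_theta, cos_theta)
    # theta / sin(theta) -> 1 as theta -> 0.
    return jnp.where(small, 1.0, theta / sin_theta) * w


def tangent_space_mean(measurements, num_iters=20):
    R = jnp.eye(3)  # Start from the identity rotation.
    for _ in range(num_iters):
        # Residuals log(R^{-1} M_i) live in the tangent space at R (R^{-1} = R^T).
        deltas = jnp.stack([so3_log(R.T @ M) for M in measurements])
        # Euclidean step in the tangent space, then retract onto SO(3).
        R = R @ so3_exp(jnp.mean(deltas, axis=0))
    return R


R_true = so3_exp(jnp.array([0.3, 0.4, 0.5]))
# Symmetric +/-0.1 rad perturbations about each axis, so R_true is the exact (Karcher) mean.
offsets = 0.1 * jnp.concatenate([jnp.eye(3), -jnp.eye(3)])
measurements = [R_true @ so3_exp(d) for d in offsets]

R_est = tangent_space_mean(measurements)
error = jnp.linalg.norm(so3_log(R_est.T @ R_true))
print(f"Geodesic error: {float(error):.2e} rad")
```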
Using SO3Var
jaxls provides built-in Lie group variables through jaxlie. The SO3Var class handles manifold optimization automatically:
# SO3Var is defined as:
# class SO3Var(
#     Var[jaxlie.SO3],
#     default_factory=jaxlie.SO3.identity,
#     retract_fn=jaxlie.manifold.rplus,  # R_new = R_old @ exp(delta).
#     tangent_dim=3,  # SO(3) has 3 degrees of freedom.
# ): ...
rotation_var = jaxls.SO3Var(id=0)
print(f"SO3Var tangent dimension: {jaxls.SO3Var.tangent_dim}")
SO3Var tangent dimension: 3
Rotation averaging cost
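For rotation averaging, we minimize the sum of squared geodesic distances to each measurement \(M_i\). The geodesic distance on SO(3) is the angle of the rotation that takes one orientation to the other:
\[
d(R, M_i) = \left\lVert \log\left(R^{-1} M_i\right) \right\rVert, \qquad \hat{R} = \arg\min_{R} \sum_i \left\lVert \log\left(R^{-1} M_i\right) \right\rVert^2.
\]
The log map returns a 3D axis-angle vector whose norm is the rotation angle, so the cost below uses \(\log(R^{-1} M_i)\) directly as its 3D residual.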
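The claim that the log-map norm equals the rotation angle can be checked numerically. This plain-JAX sketch uses illustrative hand-written helpers; jaxlie's `SO3.log` plays this role in the cost below. The test case is \(R_b = R_a \exp(\delta)\) with \(\lVert\delta\rVert = 0.2\). Both \(\lVert\log(R_a^{-1} R_b)\rVert\) and the trace formula \(\arccos\big((\operatorname{tr}(R_a^{-1} R_b) - 1)/2\big)\) give a geodesic distance of 0.2 rad.

```python
import jax.numpy as jnp


def hat(omega):
    """Axis-angle vector in R^3 -> 3x3 skew-symmetric matrix."""
    return jnp.array(
        [
            [0.0, -omega[2], omega[1]],
            [omega[2], 0.0, -omega[0]],
            [-omega[1], omega[0], 0.0],
        ]
    )


def so3_exp(omega):
    """Rodrigues' formula, with a Taylor fallback near omega = 0."""
    theta_sq = jnp.sum(omega**2)
    small = theta_sq < 1e-8
    theta = jnp.sqrt(jnp.where(small, 1.0, theta_sq))
    half = 0.5 * theta
    a = jnp.where(small, 1.0 - theta_sq / 6.0, jnp.sin(theta) / theta)
    b = jnp.where(small, 0.5 - theta_sq / 24.0, 0.5 * (jnp.sin(half) / half) ** 2)
    K = hat(omega)
    return jnp.eye(3) + a * K + b * (K @ K)


def so3_log(R):
    """Inverse of so3_exp for rotation angles well below pi."""
    w = 0.5 * jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    cos_theta = 0.5 * (jnp.trace(R) - 1.0)
    sin_sq = jnp.sum(w**2)
    small = sin_sq < 1e-12
    sin_theta = jnp.sqrt(jnp.where(small, 1.0, sin_sq))
    theta = jnp.arctan2(sin_theta, cos_theta)
    return jnp.where(small, 1.0, theta / sin_theta) * w


R_a = so3_exp(jnp.array([0.3, 0.4, 0.5]))
R_b = R_a @ so3_exp(jnp.array([0.0, 0.2, 0.0]))
R_rel = R_a.T @ R_b  # R_a^{-1} R_b for rotation matrices.

angle_log = jnp.linalg.norm(so3_log(R_rel))
angle_trace = jnp.arccos(jnp.clip(0.5 * (jnp.trace(R_rel) - 1.0), -1.0, 1.0))
print(f"||log(R_a^-1 R_b)|| = {float(angle_log):.4f}, trace formula = {float(angle_trace):.4f}")
```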
@jaxls.Cost.factory
def rotation_cost(
    vals: jaxls.VarValues,
    var: jaxls.SO3Var,
    measurement: jaxlie.SO3,
) -> jax.Array:
    """Penalize deviation from measurement using geodesic distance.

    Args:
        vals: Current variable values.
        var: The rotation variable to optimize.
        measurement: Target rotation measurement.

    Returns:
        3D residual vector (axis-angle difference).
    """
    R = vals[var]
    # Geodesic error: log(R^{-1} @ measurement).
    return (R.inverse() @ measurement).log()
Solving with manifold optimization
Create costs for each measurement and solve:
# Create batched costs for all measurements.
# Note: we use the original measurements (with consistent quaternion signs).
# The manifold approach works correctly regardless of quaternion sign.
costs = [
    rotation_cost(
        jaxls.SO3Var(id=jnp.zeros(num_measurements, dtype=jnp.int32)),
        measurements,  # Original measurements work fine.
    )
]
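Passing an array of ids (here all zeros) together with a batch of measurements lets one cost definition cover all ten terms, because the same variable is paired with each measurement. A rough plain-JAX analogy, not a description of jaxls internals, is `jax.vmap` with `in_axes=(None, 0)`, which broadcasts a single rotation against a stack of measurements. This sketch uses rotations about the z-axis (hypothetical helper `rot_z`) so the residuals are easy to read off: each one is \((0, 0, \text{angle})\).

```python
import jax
import jax.numpy as jnp


def so3_log(R):
    """Inverse of the SO(3) exponential map, for angles well below pi."""
    w = 0.5 * jnp.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    cos_theta = 0.5 * (jnp.trace(R) - 1.0)
    sin_sq = jnp.sum(w**2)
    small = sin_sq < 1e-12
    sin_theta = jnp.sqrt(jnp.where(small, 1.0, sin_sq))
    theta = jnp.arctan2(sin_theta, cos_theta)
    return jnp.where(small, 1.0, theta / sin_theta) * w


def rot_z(angle):
    """Rotation by `angle` radians about the z-axis."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def residual(R, M):
    """Same form as rotation_cost: log(R^{-1} @ M), using R^{-1} = R^T."""
    return so3_log(R.T @ M)


angles = jnp.array([0.1, -0.05, 0.2, 0.0])
measurements = jax.vmap(rot_z)(angles)  # Shape (4, 3, 3).
R = jnp.eye(3)  # One shared variable value.

# in_axes=(None, 0): reuse R for every measurement, batch over the measurements.
residuals = jax.vmap(residual, in_axes=(None, 0))(R, measurements)  # Shape (4, 3).
print(residuals)
print("Sum of squared residuals:", float(jnp.sum(residuals**2)))
```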
# Start from identity rotation.
initial_vals = jaxls.VarValues.make([rotation_var])
# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [rotation_var])
# Visualize the problem structure.
problem.show()
# Analyze and solve.
problem = problem.analyze()
solution = problem.solve(initial_vals)
INFO | Building optimization problem with 10 terms and 1 variables: 10 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 10 costs, 1 variables each: rotation_cost
INFO | step #0: cost=5.6207 lambd=0.0005 inexact_tol=1.0e-02
INFO | - rotation_cost(10): 5.62067 (avg 0.18736)
INFO | accepted=True ATb_norm=9.08e+00 cost_prev=5.6207 cost_new=0.2267
INFO | step #1: cost=0.2267 lambd=0.0003 inexact_tol=1.0e-02
INFO | - rotation_cost(10): 0.22673 (avg 0.00756)
INFO | accepted=True ATb_norm=2.78e-02 cost_prev=0.2267 cost_new=0.2267
INFO | step #2: cost=0.2267 lambd=0.0001 inexact_tol=8.4e-06
INFO | - rotation_cost(10): 0.22668 (avg 0.00756)
INFO | accepted=True ATb_norm=8.53e-05 cost_prev=0.2267 cost_new=0.2267
INFO | Terminated @ iteration #3: cost=0.2267 criteria=[1 0 0], term_deltas=2.6e-07,3.5e-05,6.9e-06
# Extract result and compute error.
manifold_result = solution[rotation_var]
manifold_error = jnp.linalg.norm((manifold_result.inverse() @ ground_truth).log())
print(f"Ground truth quaternion: {ground_truth.wxyz}")
print(f"Manifold result quaternion: {manifold_result.wxyz}")
print(
    f"\nManifold approach geodesic error: {float(manifold_error):.4f} rad ({float(jnp.rad2deg(manifold_error)):.2f} deg)"
)
print(
    f"Naive approach geodesic error: {float(naive_error):.4f} rad ({float(jnp.rad2deg(naive_error)):.2f} deg)"
)
Ground truth quaternion: [0.9381483 0.14689447 0.1958593 0.24482411]
Manifold result quaternion: [0.93323416 0.13931717 0.20464188 0.26035807]
Manifold approach geodesic error: 0.0400 rad (2.29 deg)
Naive approach geodesic error: 3.0596 rad (175.30 deg)
Defining custom manifold variables
To create your own manifold variable, subclass Var with:
default_factory: A callable returning the default value
retract_fn: A function (current_value, tangent_delta) -> new_value that applies a tangent space update, where tangent_delta is a 1D array with shape (tangent_dim,)
tangent_dim: The dimension of the local tangent space
Here’s how SO3Var is implemented (you can define similar variables for other manifolds):
# Custom SO(3) variable (equivalent to jaxls.SO3Var).
class CustomSO3Var(
    jaxls.Var[jaxlie.SO3],
    default_factory=jaxlie.SO3.identity,
    retract_fn=jaxlie.manifold.rplus,  # rplus(R, delta) = R @ SO3.exp(delta).
    tangent_dim=3,
):
    """Custom SO(3) rotation variable."""
# Example: Unit sphere manifold (S^2).
def sphere_retract(point: jax.Array, delta: jax.Array) -> jax.Array:
    """Retract from tangent plane back to sphere.

    Args:
        point: Current point on unit sphere (3,).
        delta: Tangent vector in local coordinates (2,).

    Returns:
        New point on unit sphere (3,).
    """
    # Build orthonormal basis for tangent plane.
    # Choose a vector not parallel to point.
    aux = jnp.where(
        jnp.abs(point[0]) < 0.9, jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0])
    )
    e1 = aux - jnp.dot(aux, point) * point
    e1 = e1 / jnp.linalg.norm(e1)
    e2 = jnp.cross(point, e1)
    # Move in tangent plane and project back to sphere.
    new_point = point + delta[0] * e1 + delta[1] * e2
    return new_point / jnp.linalg.norm(new_point)

class UnitSphereVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array([0.0, 0.0, 1.0]),
    retract_fn=sphere_retract,
    tangent_dim=2,  # The sphere is a 2D manifold embedded in 3D.
):
    """Point on the unit sphere S^2."""
print(f"CustomSO3Var tangent_dim: {CustomSO3Var.tangent_dim}")
print(f"UnitSphereVar tangent_dim: {UnitSphereVar.tangent_dim}")
CustomSO3Var tangent_dim: 3
UnitSphereVar tangent_dim: 2
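The same recipe extends to the probability simplex mentioned at the start. One possible retraction for points with strictly positive entries is \(\operatorname{softmax}(\log p + B\delta)\), where the columns of \(B\) span the zero-sum directions, so tangent_dim is \(n - 1\). This is a sketch, not a tested jaxls recipe. The function `simplex_retract` and the commented-out `SimplexVar` class are hypothetical. Only the retraction itself is executed here; the class just mirrors the UnitSphereVar pattern above.

```python
import jax
import jax.numpy as jnp


def simplex_retract(p: jax.Array, delta: jax.Array) -> jax.Array:
    """Move a strictly positive probability vector p (n,) by delta (n - 1,)."""
    n = p.shape[0]
    # Columns e_i - (1/n) * ones, i < n: a basis for zero-sum directions.
    basis = jnp.eye(n)[:, : n - 1] - 1.0 / n
    # Softmax keeps every entry positive and makes the entries sum to 1.
    return jax.nn.softmax(jnp.log(p) + basis @ delta)


# Hypothetical variable class, following the UnitSphereVar pattern:
# class SimplexVar(
#     jaxls.Var[jax.Array],
#     default_factory=lambda: jnp.full(3, 1.0 / 3.0),
#     retract_fn=simplex_retract,
#     tangent_dim=2,  # n - 1 for points on the 3-element simplex.
# ):
#     """Point on the probability simplex."""

p = jnp.array([0.2, 0.3, 0.5])
p_new = simplex_retract(p, jnp.array([0.4, -1.0]))
print(p_new, float(jnp.sum(p_new)))

# Jacobian w.r.t. delta at zero: rank n - 1, columns sum to zero (tangent to the simplex).
J = jax.jacfwd(simplex_retract, argnums=1)(p, jnp.zeros(2))
print(J.shape, int(jnp.linalg.matrix_rank(J)))
```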
Visualization
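In this example, the manifold-aware approach recovers the rotation to within 0.0400 rad (2.29 deg) of the ground truth, while naive quaternion averaging is off by 3.0596 rad (175.30 deg). The naive approach fails because it ignores the geometry of the manifold. Component-wise averaging leaves the unit sphere of quaternions, and it lets q and -q cancel even though they represent the same rotation.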
Key points:
Define retract_fn to map tangent space updates back to the manifold
Set tangent_dim to the local degrees of freedom (not the ambient dimension)
Use geodesic costs that respect manifold geometry (e.g., the log map for Lie groups)
Built-in Lie group variables:
jaxls.SO2Var, jaxls.SO3Var for rotations
jaxls.SE2Var, jaxls.SE3Var for rigid transformations
For custom manifolds (spheres, simplices, etc.), define your own variable class with appropriate retract_fn and tangent_dim as shown above.