Covariance estimation

Computing uncertainty estimates from nonlinear least squares solutions in jaxls.

After solving a nonlinear least squares problem, we often want to quantify the uncertainty of the estimated variables. The covariance matrix, approximated by \((J^T J)^{-1}\) and optionally scaled by the residual variance, provides this information. For manifold variables, it describes uncertainty in each variable's tangent space.

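Features used: make_covariance_estimator(), CovarianceEstimator.covariance(), and LinearSolverCovarianceEstimatorConfig.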

import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np

Background

For a nonlinear least squares problem that minimizes \(\sum_i \|r_i(x)\|^2\), the covariance of the solution at a local minimum is approximated by:

\[\Sigma = (J^T J)^{-1} \sigma^2\]

where:

  • \(J\) is the Jacobian of residuals at the solution

  • \(\sigma^2\) is the residual variance, estimated as \(\|r\|^2 / (m - n)\), where \(m\) is the total number of scalar residual entries and \(n\) is the number of (tangent-space) parameters
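A minimal numpy sketch of the formula above, using a linear model so that the residual Jacobian is simply the design matrix. The helper `least_squares_covariance` is hypothetical and not part of jaxls; jaxls evaluates the same kind of expression at the solution of a nonlinear problem.

```python
import numpy as np


def least_squares_covariance(J: np.ndarray, r: np.ndarray) -> np.ndarray:
    """Sigma = (J^T J)^{-1} * sigma^2, with sigma^2 = ||r||^2 / (m - n)."""
    m, n = J.shape  # m scalar residuals, n parameters
    sigma2 = (r @ r) / (m - n)
    return np.linalg.inv(J.T @ J) * sigma2


# Fit a line y = a + b * t to noisy samples. The model is linear, so the
# Jacobian of the residuals r = J @ [a, b] - y is the design matrix J.
rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 50)
y = 1.0 + 2.0 * t + 0.1 * rng.standard_normal(t.shape)

J = np.stack([np.ones_like(t), t], axis=1)
params, *_ = np.linalg.lstsq(J, y, rcond=None)
residuals = J @ params - y
cov = least_squares_covariance(J, residuals)

print("params:", params)
print("std devs:", np.sqrt(np.diag(cov)))
```

For a model with a single constant parameter, the same formula reduces to the familiar squared standard error of the mean, \(s^2 / m\), where \(s^2\) is the sample variance.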

For manifold variables (like SE3 poses), this covariance lives in the tangent space.
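Because the covariance is expressed in the tangent space, mapping it to world coordinates needs a Jacobian. The sketch below assumes jaxls perturbs SE2 variables on the right, as \(T \exp(\delta)\) with \(\delta = [v_x, v_y, \omega]\). Under that assumption, the world-frame position covariance is \(R\,\Sigma_{vv}R^T\) to first order. The helpers (`rot2`, `se2_exp_translation`, `perturbed_position`) are hypothetical pure-numpy stand-ins, and the pose and covariance values are made up for illustration.

```python
import numpy as np


def rot2(theta: float) -> np.ndarray:
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def se2_exp_translation(delta: np.ndarray) -> np.ndarray:
    """Translation part of the standard SE2 exponential of [vx, vy, omega]."""
    vx, vy, omega = delta
    if abs(omega) < 1e-12:
        V = np.eye(2)  # small-angle limit
    else:
        V = np.array(
            [[np.sin(omega), -(1.0 - np.cos(omega))],
             [1.0 - np.cos(omega), np.sin(omega)]]
        ) / omega
    return V @ np.array([vx, vy])


def perturbed_position(t: np.ndarray, theta: float, delta: np.ndarray) -> np.ndarray:
    """World position of T @ exp(delta), where T = (rot2(theta), t)."""
    return t + rot2(theta) @ se2_exp_translation(delta)


# Hypothetical pose and an anisotropic tangent-space covariance [vx, vy, omega].
t, theta = np.array([1.0, 2.0]), np.pi / 4
cov_tangent = np.diag([4e-4, 1e-4, 1e-4])

# Central-difference Jacobian of the world position w.r.t. the tangent vector.
eps = 1e-6
J = np.zeros((2, 3))
for i in range(3):
    e = np.zeros(3)
    e[i] = eps
    J[:, i] = (perturbed_position(t, theta, e) - perturbed_position(t, theta, -e)) / (2 * eps)

# First-order propagation into a world-frame position covariance.
cov_world_xy = J @ cov_tangent @ J.T
print(J)             # approximately [R, 0]
print(cov_world_xy)  # approximately R @ cov_tangent[:2, :2] @ R.T
```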

Simple example: 2D localization

Let’s start with a simple problem: estimating a 2D pose from noisy measurements.

# Ground truth pose.
true_pose = jaxlie.SE2.from_xy_theta(1.0, 2.0, jnp.pi / 4)

# Generate noisy measurements.
np.random.seed(42)
n_measurements = 10
measurement_noise_std = 0.1

# Noisy position measurements (x, y).
measured_positions = (
    true_pose.translation()
    + jnp.array(np.random.randn(n_measurements, 2)) * measurement_noise_std
)

# Noisy orientation measurements (angles in radians).
true_angle = jnp.arctan2(
    true_pose.rotation().as_matrix()[1, 0], true_pose.rotation().as_matrix()[0, 0]
)
measured_angles = true_angle + np.random.randn(n_measurements) * measurement_noise_std

print(
    f"True pose: x={true_pose.translation()[0]:.3f}, y={true_pose.translation()[1]:.3f}, theta={float(true_angle):.3f}"
)
True pose: x=1.000, y=2.000, theta=0.785
# Define cost functions.
@jaxls.Cost.factory
def position_cost(
    vals: jaxls.VarValues, pose: jaxls.SE2Var, measured_xy: jax.Array
) -> jax.Array:
    """Residual: difference between pose position and measurement."""
    return vals[pose].translation() - measured_xy


@jaxls.Cost.factory
def orientation_cost(
    vals: jaxls.VarValues, pose: jaxls.SE2Var, measured_theta: jax.Array
) -> jax.Array:
    """Residual: difference in orientation (as log of rotation)."""
    R = vals[pose].rotation()
    R_measured = jaxlie.SO2.from_radians(measured_theta)
    return (R @ R_measured.inverse()).log()
# Build and solve the problem.
pose_var = jaxls.SE2Var(id=0)

# Batched costs: an array of variable IDs (all 0) applies every measurement to the same pose.
costs = [
    position_cost(
        jaxls.SE2Var(id=jnp.zeros(n_measurements, dtype=jnp.int32)),
        measured_positions,
    ),
    orientation_cost(
        jaxls.SE2Var(id=jnp.zeros(n_measurements, dtype=jnp.int32)),
        measured_angles,
    ),
]

initial_vals = jaxls.VarValues.make([pose_var.with_value(jaxlie.SE2.identity())])

problem = jaxls.LeastSquaresProblem(costs, [pose_var]).analyze()
solution = problem.solve(initial_vals)

estimated_pose = solution[pose_var]
est_angle = jnp.arctan2(
    estimated_pose.rotation().as_matrix()[1, 0],
    estimated_pose.rotation().as_matrix()[0, 0],
)
print(
    f"Estimated pose: x={estimated_pose.translation()[0]:.3f}, y={estimated_pose.translation()[1]:.3f}, theta={float(est_angle):.3f}"
)
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 10 costs, 1 variables each: orientation_cost
INFO     | Vectorizing group with 10 costs, 1 variables each: position_cost
INFO     |  step #0: cost=55.0646 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - orientation_cost(10): 5.88457 (avg 0.58846)
INFO     |      - position_cost(10): 49.18004 (avg 2.45900)
INFO     |      accepted=True ATb_norm=2.90e+01 cost_prev=55.0646 cost_new=7.1428
INFO     |  step #1: cost=7.1428 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - orientation_cost(10): 0.05962 (avg 0.00596)
INFO     |      - position_cost(10): 7.08317 (avg 0.35416)
INFO     |      accepted=True ATb_norm=1.03e+01 cost_prev=7.1428 cost_new=0.2347
INFO     |  step #2: cost=0.2347 lambd=0.0001 inexact_tol=1.0e-02
INFO     |      - orientation_cost(10): 0.05962 (avg 0.00596)
INFO     |      - position_cost(10): 0.17508 (avg 0.00875)
INFO     |      accepted=False ATb_norm=1.75e-04 cost_prev=0.2347 cost_new=0.2347
INFO     | Terminated @ iteration #3: cost=0.2347 criteria=[1 0 0], term_deltas=0.0e+00,1.7e-04,5.8e-06
Estimated pose: x=0.982, y=1.984, theta=0.763

Computing covariance

Use make_covariance_estimator() to create a covariance estimator, then extract blocks with covariance():

# Create covariance estimator (default: conjugate gradient solver).
estimator = problem.make_covariance_estimator(solution)

# Get the 3x3 covariance matrix for the SE2 pose (in tangent space).
# Tangent space is [vx, vy, omega] for SE2.
cov = estimator.covariance(pose_var)

print("Covariance matrix (tangent space):")
print(cov)
print(
    f"\nStandard deviations: x={jnp.sqrt(cov[0, 0]):.4f}, y={jnp.sqrt(cov[1, 1]):.4f}, theta={jnp.sqrt(cov[2, 2]):.4f}"
)
Covariance matrix (tangent space):
[[ 8.6924818e-04 -5.4531440e-12  0.0000000e+00]
 [-5.4531418e-12  8.6924795e-04  0.0000000e+00]
 [ 0.0000000e+00  0.0000000e+00  8.6924812e-04]]

Standard deviations: x=0.0295, y=0.0295, theta=0.0295
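The printed value can be reproduced by hand as a consistency check. This sketch makes two assumptions. First, jaxls perturbs SE2 poses as \(T \exp(\delta)\), so each position residual has Jacobian \([R, 0]\) and each orientation residual has Jacobian \([0, 0, 1]\). Second, the final `cost=0.2347` in the solver log is the plain sum of squared residuals. With 30 scalar residuals and 3 tangent dimensions, \(J^T J = 10 I\) and \(\sigma^2 = 0.2347 / 27\).

```python
import numpy as np

# Assumption: right perturbation T @ exp([vx, vy, omega]) for SE2 variables.
theta = 0.763  # estimated heading, from the output above
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta), np.cos(theta)]])

n_measurements = 10
J_position = np.hstack([R, np.zeros((2, 1))])  # 2x3 Jacobian per position residual
J_orientation = np.array([[0.0, 0.0, 1.0]])    # 1x3 Jacobian per orientation residual
J = np.vstack([J_position] * n_measurements + [J_orientation] * n_measurements)

m, n = J.shape                 # 30 scalar residuals, 3 tangent dimensions
final_cost = 0.2347            # sum of squared residuals, from the solver log
sigma2 = final_cost / (m - n)  # residual variance estimate
cov_check = np.linalg.inv(J.T @ J) * sigma2

print(np.diag(cov_check))  # each entry is about 8.69e-4
```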


import plotly.graph_objects as go
from IPython.display import HTML


def make_ellipse_trace(
    center: tuple[float, float],
    cov_2d: jax.Array,
    n_std: float = 2.0,
    name: str = "Uncertainty",
    color: str = "#2196F3",
    dash: str = "solid",
) -> go.Scatter:
    """Create a 2D uncertainty ellipse trace."""
    # Eigendecomposition for ellipse axes.
    eigvals, eigvecs = jnp.linalg.eigh(cov_2d)
    # Scale by number of standard deviations.
    width = 2 * n_std * jnp.sqrt(eigvals[0])
    height = 2 * n_std * jnp.sqrt(eigvals[1])
    angle = jnp.arctan2(eigvecs[1, 0], eigvecs[0, 0])

    # Generate ellipse points.
    t = jnp.linspace(0, 2 * jnp.pi, 100)
    x_ellipse = width / 2 * jnp.cos(t)
    y_ellipse = height / 2 * jnp.sin(t)

    # Rotate.
    cos_a, sin_a = jnp.cos(angle), jnp.sin(angle)
    x_rot = cos_a * x_ellipse - sin_a * y_ellipse + center[0]
    y_rot = sin_a * x_ellipse + cos_a * y_ellipse + center[1]

    return go.Scatter(
        x=x_rot,
        y=y_rot,
        mode="lines",
        name=name,
        line=dict(color=color, width=2, dash=dash),
        fill="toself",
        fillcolor=f"rgba({int(color[1:3], 16)}, {int(color[3:5], 16)}, {int(color[5:7], 16)}, 0.2)",
    )


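# Extract the 2x2 translational block of the tangent-space covariance.
# Tangent-space translation is expressed in the pose's local frame (jaxls perturbs
# SE2 variables as T @ exp(delta)), so rotate it into the world frame before plotting.
# Here the block is isotropic, so the rotation leaves the ellipse unchanged.
R_est = estimated_pose.rotation().as_matrix()
cov_xy = R_est @ cov[:2, :2] @ R_est.T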

fig = go.Figure()

# Measurements.
fig.add_trace(
    go.Scatter(
        x=measured_positions[:, 0],
        y=measured_positions[:, 1],
        mode="markers",
        marker=dict(size=8, color="#9E9E9E"),
        name="Measurements",
    )
)

# True pose.
fig.add_trace(
    go.Scatter(
        x=[true_pose.translation()[0]],
        y=[true_pose.translation()[1]],
        mode="markers",
        marker=dict(size=12, color="#4CAF50", symbol="star"),
        name="True pose",
    )
)

# Estimated pose.
fig.add_trace(
    go.Scatter(
        x=[estimated_pose.translation()[0]],
        y=[estimated_pose.translation()[1]],
        mode="markers",
        marker=dict(size=12, color="#2196F3", symbol="circle"),
        name="Estimated pose",
    )
)

# 2-sigma uncertainty ellipse.
fig.add_trace(
    make_ellipse_trace(
        (
            float(estimated_pose.translation()[0]),
            float(estimated_pose.translation()[1]),
        ),
        cov_xy,
        n_std=2.0,
        name="2σ uncertainty",
        color="#2196F3",
    )
)

fig.update_xaxes(title_text="x", scaleanchor="y", scaleratio=1)
fig.update_yaxes(title_text="y")
fig.update_layout(
    title="2D Pose Estimation with Uncertainty",
    height=500,
    margin=dict(t=60, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))
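In 2D, a "2σ" ellipse encloses less probability than a 1D 2σ interval. The squared Mahalanobis distance of a 2D Gaussian sample follows a chi-square distribution with 2 degrees of freedom, with CDF \(1 - e^{-k/2}\). The helper names below are illustrative:

```python
import math


def ellipse_coverage(n_std: float) -> float:
    """Probability mass of a 2D Gaussian inside its n_std-sigma ellipse."""
    return 1.0 - math.exp(-(n_std**2) / 2.0)


def n_std_for_coverage(p: float) -> float:
    """Ellipse scale (in standard deviations) that encloses probability p in 2D."""
    return math.sqrt(-2.0 * math.log(1.0 - p))


print(f"{ellipse_coverage(2.0):.3f}")     # 0.865
print(f"{n_std_for_coverage(0.95):.3f}")  # 2.448
```

To draw a 95% region, pass `n_std=n_std_for_coverage(0.95)` to `make_ellipse_trace` instead of 2.0.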

EKF measurement update as Gauss-Newton

The measurement update step of an Extended Kalman Filter (EKF) is equivalent to a single Gauss-Newton iteration on a least squares problem with:

  • A prior cost (encoding the predicted state and covariance)

  • Measurement costs

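With whitened residuals (each scaled by the inverse square root of its noise covariance) and no residual-variance scaling, the posterior covariance \((J^T J)^{-1}\) matches the EKF covariance update. The equivalence is exact for linear measurement models. For nonlinear models, the EKF linearizes once at the prior mean; iterating Gauss-Newton to convergence (as in the Iterated EKF) relinearizes at each step and can improve accuracy.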
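The equivalence can be checked numerically. The following numpy sketch compares the textbook Kalman update with one Gauss-Newton step on whitened prior and measurement residuals, starting from the prior mean. The prior and noise values mirror the example that follows, and the measured \(x, y\) come from its printed output. The heading measurement `0.8` is a hypothetical value, and both helper functions are illustrative, not jaxls APIs.

```python
import numpy as np


def kalman_update(x, P, z, H, R):
    """Standard Kalman measurement update for z = H x + noise."""
    S = H @ P @ H.T + R
    K = P @ H.T @ np.linalg.inv(S)
    x_post = x + K @ (z - H @ x)
    P_post = (np.eye(len(x)) - K @ H) @ P
    return x_post, P_post


def gauss_newton_update(x, P, z, H, R):
    """One Gauss-Newton step from the prior mean on whitened residuals."""
    # Whitening factors W with W^T W = inverse covariance (upper Cholesky factor).
    W_prior = np.linalg.cholesky(np.linalg.inv(P)).T
    W_meas = np.linalg.cholesky(np.linalg.inv(R)).T
    # Stack residuals at the prior mean; the prior residual W_prior (x - x_prior) is zero there.
    J = np.vstack([W_prior, W_meas @ H])
    r = np.concatenate([np.zeros(len(x)), W_meas @ (H @ x - z)])
    JTJ = J.T @ J
    dx = -np.linalg.solve(JTJ, J.T @ r)
    return x + dx, np.linalg.inv(JTJ)


x_prior = np.array([0.5, 1.5, np.pi / 6])
P_prior = np.diag([0.5**2, 0.5**2, 0.3**2])
z = np.array([1.050, 1.986, 0.8])  # heading value is hypothetical
H = np.eye(3)
R_meas = 0.3**2 * np.eye(3)

x_kf, P_kf = kalman_update(x_prior, P_prior, z, H, R_meas)
x_gn, P_gn = gauss_newton_update(x_prior, P_prior, z, H, R_meas)
print(x_kf[:2], np.sqrt(np.diag(P_kf))[:2])  # about [0.904 1.857] and [0.257 0.257]
```

For a nonlinear measurement function, replacing `H` with its Jacobian evaluated at the prior mean gives the EKF, and the two updates still coincide for that single step.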

# Simulate a single EKF-style measurement update.
# Prior: initial pose estimate with uncertainty.
prior_pose = jaxlie.SE2.from_xy_theta(0.5, 1.5, jnp.pi / 6)
prior_cov = jnp.diag(jnp.array([0.5**2, 0.5**2, 0.3**2]))  # Prior uncertainty.


# For EKF equivalence, use Euclidean parameterization [x, y, theta].
class PoseVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
    """Pose as [x, y, theta] vector."""


@jaxls.Cost.factory
def euclidean_prior_cost(
    vals: jaxls.VarValues,
    pose: PoseVar,
    prior_mean: jax.Array,
    info_sqrt: jax.Array,
) -> jax.Array:
    """Prior cost in Euclidean coordinates."""
    error = vals[pose] - prior_mean
    return info_sqrt @ error


# Information matrix is inverse of covariance.
prior_info = jnp.linalg.inv(prior_cov)
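# Whitening factor U with U^T U = prior_info, so that ||U e||^2 = e^T prior_info e.
# jnp.linalg.cholesky returns the lower factor L (L L^T = prior_info), so use U = L^T.
prior_info_sqrt = jnp.linalg.cholesky(prior_info).T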
prior_mean = jnp.array([0.5, 1.5, jnp.pi / 6])
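A cost that returns `info_sqrt @ error` contributes \(e^T (U^T U) e\) with \(U\) = `info_sqrt`. The factor therefore needs to satisfy \(U^T U = \Lambda\). The Cholesky routines in numpy and JAX return the lower factor \(L\) with \(L L^T = \Lambda\), so \(L^T\) is the correct choice. For the diagonal `prior_cov` used here, \(L = L^T\) and the distinction does not affect the results. The numpy check below uses a made-up non-diagonal information matrix where the two choices differ:

```python
import numpy as np

# Hypothetical non-diagonal information matrix (inverse covariance) and error.
info = np.array([[4.0, 1.0],
                 [1.0, 3.0]])
e = np.array([0.3, -0.7])

L = np.linalg.cholesky(info)  # lower factor: L @ L.T == info
U = L.T                       # upper factor: U.T @ U == info

target = e @ info @ e              # squared Mahalanobis norm the cost should produce
with_upper = np.sum((U @ e) ** 2)  # matches target
with_lower = np.sum((L @ e) ** 2)  # differs from target for non-diagonal info
print(target, with_upper, with_lower)
```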
# Take a single measurement.
single_measurement_xy = measured_positions[0]
single_measurement_theta = measured_angles[0]

# Measurement noise standard deviation (shared by the position and orientation costs).
# Using higher noise than the data was generated with, to show the prior's influence.
meas_noise_std = 0.3
meas_info_sqrt = 1.0 / meas_noise_std  # Weight = 1/sigma for residuals.


@jaxls.Cost.factory
def euclidean_position_cost(
    vals: jaxls.VarValues,
    pose: PoseVar,
    measured_xy: jax.Array,
    weight: float,
) -> jax.Array:
    """Weighted position residual (Euclidean)."""
    return weight * (vals[pose][:2] - measured_xy)


@jaxls.Cost.factory
def euclidean_orientation_cost(
    vals: jaxls.VarValues,
    pose: PoseVar,
    measured_theta: jax.Array,
    weight: float,
) -> jax.Array:
    """Weighted orientation residual (Euclidean angle difference)."""
    # Angle difference, wrapped to [-pi, pi] so the residual stays small
    # even when the two angles straddle the +/-pi boundary.
    angle_diff = vals[pose][2] - measured_theta
    angle_diff = jnp.arctan2(jnp.sin(angle_diff), jnp.cos(angle_diff))
    return weight * angle_diff


# EKF update = single Gauss-Newton step + covariance.
ekf_pose_var = PoseVar(id=0)

ekf_costs = [
    euclidean_prior_cost(ekf_pose_var, prior_mean, prior_info_sqrt),
    euclidean_position_cost(ekf_pose_var, single_measurement_xy, meas_info_sqrt),
    euclidean_orientation_cost(ekf_pose_var, single_measurement_theta, meas_info_sqrt),
]

ekf_initial = jaxls.VarValues.make([ekf_pose_var.with_value(prior_mean)])
ekf_problem = jaxls.LeastSquaresProblem(ekf_costs, [ekf_pose_var]).analyze()

# Single iteration (like EKF).
ekf_solution = ekf_problem.solve(
    ekf_initial,
    trust_region=None,  # Pure Gauss-Newton (no trust region).
    termination=jaxls.TerminationConfig(max_iterations=1),
    verbose=False,
)

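# Posterior covariance. The residuals are already whitened by their noise models,
# so skip the sigma^2 = ||r||^2 / (m - n) rescaling.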
ekf_estimator = ekf_problem.make_covariance_estimator(
    ekf_solution, scale_by_residual_variance=False
)
posterior_cov = ekf_estimator.covariance(ekf_pose_var)

ekf_pose_vec = ekf_solution[ekf_pose_var]

print("EKF-style update (1 Gauss-Newton iteration):")
print(f"  Prior:     x={prior_mean[0]:.3f}, y={prior_mean[1]:.3f}")
print(
    f"  Measurement: x={single_measurement_xy[0]:.3f}, y={single_measurement_xy[1]:.3f}"
)
print(f"  Posterior: x={ekf_pose_vec[0]:.3f}, y={ekf_pose_vec[1]:.3f}")
print(
    f"\nPrior uncertainty (std): x={jnp.sqrt(prior_cov[0, 0]):.3f}, y={jnp.sqrt(prior_cov[1, 1]):.3f}"
)
print(
    f"Posterior uncertainty (std): x={jnp.sqrt(posterior_cov[0, 0]):.3f}, y={jnp.sqrt(posterior_cov[1, 1]):.3f}"
)
INFO     | Building optimization problem with 3 terms and 1 variables: 3 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 1 costs, 1 variables each: euclidean_prior_cost
INFO     | Vectorizing group with 1 costs, 1 variables each: euclidean_position_cost
INFO     | Vectorizing group with 1 costs, 1 variables each: euclidean_orientation_cost
EKF-style update (1 Gauss-Newton iteration):
  Prior:     x=0.500, y=1.500
  Measurement: x=1.050, y=1.986
  Posterior: x=0.904, y=1.857

Prior uncertainty (std): x=0.500, y=0.500
Posterior uncertainty (std): x=0.257, y=0.257
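Because the measurement model in this example is linear (the measurement Jacobian is the identity) and both covariances are diagonal, the printed x results can be checked by hand with the scalar information-form update. The check uses \(\sigma_\text{prior} = 0.5\), \(\sigma_\text{meas} = 0.3\), and the printed measurement \(z_x = 1.050\); the y values follow the same pattern.

\[\sigma_\text{post} = \left(\sigma_\text{prior}^{-2} + \sigma_\text{meas}^{-2}\right)^{-1/2} = \left(4 + 11.11\right)^{-1/2} \approx 0.257\]

\[x_\text{post} = \frac{\sigma_\text{prior}^{-2}\, x_\text{prior} + \sigma_\text{meas}^{-2}\, z_x}{\sigma_\text{prior}^{-2} + \sigma_\text{meas}^{-2}} = \frac{4 \cdot 0.5 + 11.11 \cdot 1.050}{15.11} \approx 0.904\]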


# Visualize EKF update.
fig_ekf = go.Figure()

# Prior ellipse.
fig_ekf.add_trace(
    make_ellipse_trace(
        (float(prior_mean[0]), float(prior_mean[1])),
        prior_cov[:2, :2],
        n_std=2.0,
        name="Prior 2σ",
        color="#FF9800",
        dash="dash",
    )
)

# Prior mean.
fig_ekf.add_trace(
    go.Scatter(
        x=[prior_mean[0]],
        y=[prior_mean[1]],
        mode="markers",
        marker=dict(size=10, color="#FF9800", symbol="diamond"),
        name="Prior",
    )
)

# Measurement.
fig_ekf.add_trace(
    go.Scatter(
        x=[single_measurement_xy[0]],
        y=[single_measurement_xy[1]],
        mode="markers",
        marker=dict(size=10, color="#9E9E9E", symbol="cross"),
        name="Measurement",
    )
)

# Posterior ellipse.
fig_ekf.add_trace(
    make_ellipse_trace(
        (float(ekf_pose_vec[0]), float(ekf_pose_vec[1])),
        posterior_cov[:2, :2],
        n_std=2.0,
        name="Posterior 2σ",
        color="#2196F3",
    )
)

# Posterior mean.
fig_ekf.add_trace(
    go.Scatter(
        x=[ekf_pose_vec[0]],
        y=[ekf_pose_vec[1]],
        mode="markers",
        marker=dict(size=10, color="#2196F3", symbol="circle"),
        name="Posterior",
    )
)

# True pose.
fig_ekf.add_trace(
    go.Scatter(
        x=[true_pose.translation()[0]],
        y=[true_pose.translation()[1]],
        mode="markers",
        marker=dict(size=12, color="#4CAF50", symbol="star"),
        name="True pose",
    )
)

# Arrow from prior to posterior.
fig_ekf.add_annotation(
    x=float(ekf_pose_vec[0]),
    y=float(ekf_pose_vec[1]),
    ax=float(prior_mean[0]),
    ay=float(prior_mean[1]),
    xref="x",
    yref="y",
    axref="x",
    ayref="y",
    showarrow=True,
    arrowhead=2,
    arrowsize=1.5,
    arrowwidth=2,
    arrowcolor="#607D8B",
)

fig_ekf.update_xaxes(title_text="x", scaleanchor="y", scaleratio=1)
fig_ekf.update_yaxes(title_text="y")
fig_ekf.update_layout(
    title="EKF Update: Prior + Measurement → Posterior",
    height=500,
    margin=dict(t=60, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig_ekf.to_html(full_html=False, include_plotlyjs="cdn"))

Cross-covariance between variables

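For problems with multiple variables, we can also compute cross-covariance blocks, which are the off-diagonal blocks of the full covariance \((J^T J)^{-1}\). Each marginal covariance is the corresponding diagonal block of that inverse, not the inverse of a diagonal block of \(J^T J\).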

# Two-pose problem with relative constraint.
pose0 = jaxls.SE2Var(id=0)
pose1 = jaxls.SE2Var(id=1)


@jaxls.Cost.factory
def anchor_cost(
    vals: jaxls.VarValues, pose: jaxls.SE2Var, target: jaxlie.SE2
) -> jax.Array:
    """Anchor pose to a target."""
    return (vals[pose] @ target.inverse()).log()


@jaxls.Cost.factory
def relative_cost(
    vals: jaxls.VarValues,
    pose_a: jaxls.SE2Var,
    pose_b: jaxls.SE2Var,
    delta: jaxlie.SE2,
) -> jax.Array:
    """Relative pose constraint."""
    T_a = vals[pose_a]
    T_b = vals[pose_b]
    measured = T_a.inverse() @ T_b
    return (measured @ delta.inverse()).log()


two_pose_costs = [
    anchor_cost(pose0, jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0)),
    anchor_cost(pose1, jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0)),
    relative_cost(pose0, pose1, jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
]

two_pose_initial = jaxls.VarValues.make(
    [
        pose0.with_value(jaxlie.SE2.identity()),
        pose1.with_value(jaxlie.SE2.from_xy_theta(1.5, 0.0, 0.0)),
    ]
)

two_pose_problem = jaxls.LeastSquaresProblem(two_pose_costs, [pose0, pose1]).analyze()
two_pose_solution = two_pose_problem.solve(two_pose_initial)

# Get covariance estimator.
two_pose_estimator = two_pose_problem.make_covariance_estimator(
    two_pose_solution, scale_by_residual_variance=False
)

# Marginal covariances.
cov_00 = two_pose_estimator.covariance(pose0)
cov_11 = two_pose_estimator.covariance(pose1)

# Cross-covariance.
cov_01 = two_pose_estimator.covariance(pose0, pose1)

print("Marginal covariance pose0:")
print(cov_00)
print("\nMarginal covariance pose1:")
print(cov_11)
print("\nCross-covariance pose0-pose1:")
print(cov_01)
INFO     | Building optimization problem with 3 terms and 2 variables: 3 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 1 costs, 2 variables each: relative_cost
INFO     | Vectorizing group with 2 costs, 1 variables each: anchor_cost
INFO     |  step #0: cost=0.5000 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - relative_cost(1): 0.25000 (avg 0.08333)
INFO     |      - anchor_cost(2): 0.25000 (avg 0.04167)
INFO     |      accepted=True ATb_norm=7.07e-01 cost_prev=0.5000 cost_new=0.3333
INFO     |  step #1: cost=0.3333 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - relative_cost(1): 0.11112 (avg 0.03704)
INFO     |      - anchor_cost(2): 0.22221 (avg 0.03704)
INFO     |      accepted=True ATb_norm=1.34e-04 cost_prev=0.3333 cost_new=0.3333
INFO     | Terminated @ iteration #2: cost=0.3333 criteria=[1 0 0], term_deltas=1.8e-07,1.2e-04,3.7e-05
Marginal covariance pose0:
[[0.6666667  0.         0.        ]
 [0.         0.7081737  0.11111112]
 [0.         0.1111111  0.6666667 ]]

Marginal covariance pose1:
[[0.6666667  0.         0.        ]
 [0.         2.4351852  1.0555556 ]
 [0.         1.0555556  0.63218397]]

Cross-covariance pose0-pose1:
[[0.33333334 0.         0.        ]
 [0.         0.6018518  0.15900384]
 [0.         0.6111111  0.33333337]]
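The first row and column of each printed block can be reproduced by hand. At the solution, both poses have zero heading and zero y, so the x-direction decouples from y and \(\theta\), as the zero entries above show. Each cost then contributes one unit-weight row in the x-coordinates. The numpy sketch below builds that 3×2 Jacobian. It also shows that a marginal is a block of the inverse, which differs from inverting a block of \(J^T J\).

```python
import numpy as np

# x-coordinates only: residuals x0 - 0, x1 - 2, (x1 - x0) - 1, all unit weight.
J = np.array([[1.0, 0.0],    # anchor_cost on pose0
              [0.0, 1.0],    # anchor_cost on pose1
              [-1.0, 1.0]])  # relative_cost between pose0 and pose1
H = J.T @ J                  # [[2, -1], [-1, 2]]
cov_full = np.linalg.inv(H)  # [[2/3, 1/3], [1/3, 2/3]]

marginal_x0 = cov_full[0, 0]  # 0.6667, matching cov_00[0, 0]
cross_x0_x1 = cov_full[0, 1]  # 0.3333, matching cov_01[0, 0]
naive_x0 = 1.0 / H[0, 0]      # 0.5: inverting a block of H underestimates the variance
print(marginal_x0, cross_x0_x1, naive_x0)
```

Mathematically, the block for the reversed pair (pose1, pose0) is the transpose of the (pose0, pose1) block. Cross-covariance blocks are therefore not symmetric in general, as the printed y/theta entries show.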

Estimator options

jaxls provides different methods for computing covariances:

  • Conjugate Gradient (default): GPU-friendly iterative solver. Fast when variables are weakly correlated.

  • Dense Cholesky: Direct solver using dense matrices. Suitable for small to medium problems.

  • CHOLMOD: Sparse direct solver (requires the scikit-sparse package, imported as sksparse). Caches the factorization for efficient repeated queries on large sparse problems.
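All three options avoid forming the full inverse. Column \(j\) of \((J^T J)^{-1}\) solves \((J^T J)x = e_j\), so a 3×3 block needs only three linear solves, which can be iterative (CG) or direct (Cholesky). The numpy sketch below illustrates this idea on a random, made-up Jacobian. It is not jaxls' implementation, and the helper names are hypothetical.

```python
import numpy as np


def conjugate_gradient(A, b, tol=1e-12, max_iter=None):
    """Plain conjugate gradient for a symmetric positive-definite A."""
    x = np.zeros_like(b)
    r = b - A @ x
    p = r.copy()
    rs_old = r @ r
    for _ in range(max_iter or 10 * len(b)):
        Ap = A @ p
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def cholesky_solve(A, b):
    """Direct solve of A x = b using A = L L^T."""
    L = np.linalg.cholesky(A)
    return np.linalg.solve(L.T, np.linalg.solve(L, b))


def covariance_block(H, rows, cols, solve):
    """Block H^{-1}[rows, cols], using one linear solve per requested column."""
    n = H.shape[0]
    columns = [solve(H, np.eye(n)[:, j]) for j in cols]
    return np.stack(columns, axis=1)[rows, :]


# Two 3-dof variables stacked into a 6-dim tangent vector.
rng = np.random.default_rng(0)
J = rng.standard_normal((20, 6))
H = J.T @ J
var0, var1 = np.arange(0, 3), np.arange(3, 6)

block_cg = covariance_block(H, var0, var1, conjugate_gradient)
block_chol = covariance_block(H, var0, var1, cholesky_solve)
print(np.max(np.abs(block_cg - block_chol)))  # tiny: both approximate inv(H)[0:3, 3:6]
```

A direct solver can reuse one factorization across many columns, which is why caching the factorization (as described for CHOLMOD) pays off for repeated queries.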

# Default: Conjugate Gradient.
estimator_cg = problem.make_covariance_estimator(solution)

# Dense Cholesky.
estimator_dense = problem.make_covariance_estimator(
    solution,
    method=jaxls.LinearSolverCovarianceEstimatorConfig(linear_solver="dense_cholesky"),
)

# Compare results.
cov_cg = estimator_cg.covariance(pose_var)
cov_dense = estimator_dense.covariance(pose_var)

print("CG and Dense Cholesky give the same result:")
print(f"  Max difference: {jnp.max(jnp.abs(cov_cg - cov_dense)):.2e}")
CG and Dense Cholesky give the same result:
  Max difference: 2.33e-10

Summary

Covariance estimation in jaxls:

  1. After solving, create a covariance estimator with problem.make_covariance_estimator(solution).

  2. Extract blocks using estimator.covariance(var) for marginal covariance or estimator.covariance(var0, var1) for cross-covariance.

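  3. EKF connection: With whitened residuals and a prior cost, the EKF measurement update equals one Gauss-Newton iteration from the prior mean, with posterior covariance \((J^T J)^{-1}\) (no residual-variance scaling).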

  4. Multiple methods: CG (default, GPU-friendly), dense Cholesky, or CHOLMOD for large sparse problems.

For more details, see CovarianceEstimator, LinearSolverCovarianceEstimatorConfig, and make_covariance_estimator().