Covariance estimation#
Computing uncertainty estimates from nonlinear least squares solutions in jaxls.
After solving a nonlinear least squares problem, we often want to quantify uncertainty in the estimated variables. The covariance matrix \((J^T J)^{-1}\) provides this information, representing the tangent-space uncertainty of each variable.
Features used:
make_covariance_estimator() for creating estimators
CovarianceEstimator for extracting covariance blocks
LinearSolverCovarianceEstimatorConfig for configuration
import sys
from loguru import logger
logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np
Background#
For a nonlinear least squares problem that minimizes \(\sum_i \|r_i(x)\|^2\), the covariance of the solution at a local minimum is approximated by:
\[
\Sigma \approx \sigma^2 (J^T J)^{-1}
\]
where:
\(J\) is the Jacobian of residuals at the solution
\(\sigma^2\) is the residual variance, estimated as \(\|r\|^2 / (m - n)\) where \(m\) is the number of residuals and \(n\) is the number of parameters
For manifold variables (like SE3 poses), this covariance lives in the tangent space.
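To make the tangent-space interpretation concrete, here is a minimal sketch (with a hypothetical pose and covariance) that draws tangent perturbations and retracts them onto the manifold, assuming the right-perturbation convention \(T \exp(\tau)\) implemented by jaxlie.manifold.rplus:

```python
import jax.numpy as jnp
import jaxlie
import numpy as np

# Hypothetical SE2 estimate and 3x3 tangent-space covariance over [vx, vy, omega].
pose = jaxlie.SE2.from_xy_theta(1.0, 2.0, 0.5)
cov = np.diag([0.01, 0.01, 0.005])

# Draw tangent perturbations tau ~ N(0, cov)...
rng = np.random.default_rng(0)
taus = rng.multivariate_normal(mean=np.zeros(3), cov=cov, size=5)

# ...and retract each onto the manifold: sample_i = pose @ exp(tau_i).
pose_samples = [jaxlie.manifold.rplus(pose, jnp.asarray(tau)) for tau in taus]
```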
Simple example: 2D localization#
Let’s start with a simple problem: estimating a 2D pose from noisy measurements.
# Ground truth pose.
true_pose = jaxlie.SE2.from_xy_theta(1.0, 2.0, jnp.pi / 4)
# Generate noisy measurements.
np.random.seed(42)
n_measurements = 10
measurement_noise_std = 0.1
# Noisy position measurements (x, y).
measured_positions = (
true_pose.translation()
+ jnp.array(np.random.randn(n_measurements, 2)) * measurement_noise_std
)
# Noisy orientation measurements (as unit vectors).
true_angle = jnp.arctan2(
true_pose.rotation().as_matrix()[1, 0], true_pose.rotation().as_matrix()[0, 0]
)
measured_angles = true_angle + np.random.randn(n_measurements) * measurement_noise_std
print(
f"True pose: x={true_pose.translation()[0]:.3f}, y={true_pose.translation()[1]:.3f}, theta={float(true_angle):.3f}"
)
True pose: x=1.000, y=2.000, theta=0.785
# Define cost functions.
@jaxls.Cost.factory
def position_cost(
vals: jaxls.VarValues, pose: jaxls.SE2Var, measured_xy: jax.Array
) -> jax.Array:
"""Residual: difference between pose position and measurement."""
return vals[pose].translation() - measured_xy
@jaxls.Cost.factory
def orientation_cost(
vals: jaxls.VarValues, pose: jaxls.SE2Var, measured_theta: jax.Array
) -> jax.Array:
"""Residual: difference in orientation (as log of rotation)."""
R = vals[pose].rotation()
R_measured = jaxlie.SO2.from_radians(measured_theta)
return (R @ R_measured.inverse()).log()
# Build and solve the problem.
pose_var = jaxls.SE2Var(id=0)
# Batch costs.
costs = [
position_cost(
jaxls.SE2Var(id=jnp.zeros(n_measurements, dtype=jnp.int32)),
measured_positions,
),
orientation_cost(
jaxls.SE2Var(id=jnp.zeros(n_measurements, dtype=jnp.int32)),
measured_angles,
),
]
initial_vals = jaxls.VarValues.make([pose_var.with_value(jaxlie.SE2.identity())])
problem = jaxls.LeastSquaresProblem(costs, [pose_var]).analyze()
solution = problem.solve(initial_vals)
estimated_pose = solution[pose_var]
est_angle = jnp.arctan2(
estimated_pose.rotation().as_matrix()[1, 0],
estimated_pose.rotation().as_matrix()[0, 0],
)
print(
f"Estimated pose: x={estimated_pose.translation()[0]:.3f}, y={estimated_pose.translation()[1]:.3f}, theta={float(est_angle):.3f}"
)
INFO | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 10 costs, 1 variables each: orientation_cost
INFO | Vectorizing group with 10 costs, 1 variables each: position_cost
INFO | step #0: cost=55.0646 lambd=0.0005 inexact_tol=1.0e-02
INFO | - orientation_cost(10): 5.88457 (avg 0.58846)
INFO | - position_cost(10): 49.18004 (avg 2.45900)
INFO | accepted=True ATb_norm=2.90e+01 cost_prev=55.0646 cost_new=7.1428
INFO | step #1: cost=7.1428 lambd=0.0003 inexact_tol=1.0e-02
INFO | - orientation_cost(10): 0.05962 (avg 0.00596)
INFO | - position_cost(10): 7.08317 (avg 0.35416)
INFO | accepted=True ATb_norm=1.03e+01 cost_prev=7.1428 cost_new=0.2347
INFO | step #2: cost=0.2347 lambd=0.0001 inexact_tol=1.0e-02
INFO | - orientation_cost(10): 0.05962 (avg 0.00596)
INFO | - position_cost(10): 0.17508 (avg 0.00875)
INFO | accepted=False ATb_norm=1.75e-04 cost_prev=0.2347 cost_new=0.2347
INFO | Terminated @ iteration #3: cost=0.2347 criteria=[1 0 0], term_deltas=0.0e+00,1.7e-04,5.8e-06
Estimated pose: x=0.982, y=1.984, theta=0.763
Computing covariance#
Use make_covariance_estimator() to create a covariance estimator, then extract blocks with covariance():
# Create covariance estimator (default: conjugate gradient solver).
estimator = problem.make_covariance_estimator(solution)
# Get the 3x3 covariance matrix for the SE2 pose (in tangent space).
# Tangent space is [vx, vy, omega] for SE2.
cov = estimator.covariance(pose_var)
print("Covariance matrix (tangent space):")
print(cov)
print(
f"\nStandard deviations: x={jnp.sqrt(cov[0, 0]):.4f}, y={jnp.sqrt(cov[1, 1]):.4f}, theta={jnp.sqrt(cov[2, 2]):.4f}"
)
Covariance matrix (tangent space):
[[ 8.6924818e-04 -5.4531440e-12 0.0000000e+00]
[-5.4531418e-12 8.6924795e-04 0.0000000e+00]
[ 0.0000000e+00 0.0000000e+00 8.6924812e-04]]
Standard deviations: x=0.0295, y=0.0295, theta=0.0295
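As a sanity check: averaging \(n\) independent measurements with noise standard deviation \(\sigma\) should give a posterior standard deviation of roughly \(\sigma / \sqrt{n}\) per axis. A short comparison, reusing the variables defined above:

```python
# With 10 measurements of noise std 0.1, expect roughly 0.1 / sqrt(10) ~= 0.0316
# per axis. The estimated 0.0295 differs slightly because sigma^2 is estimated
# from the residuals rather than assumed.
expected_std = measurement_noise_std / jnp.sqrt(n_measurements)
print(f"Analytic std per axis: {expected_std:.4f}")
print(f"Estimated stds (diag): {jnp.sqrt(jnp.diag(cov))}")
```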
EKF measurement update as Gauss-Newton#
The measurement update step of an Extended Kalman Filter (EKF) is equivalent to a single Gauss-Newton iteration on a least squares problem with:
A prior cost (encoding the predicted state and covariance)
Measurement costs
The posterior covariance \((J^T J)^{-1}\) matches the EKF covariance update. This equivalence is exact for linear systems; for nonlinear systems, iterating to convergence (as in Iterated EKF) can improve accuracy.
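Concretely, with measurement Jacobian \(H\) and measurement noise covariance \(R\), the information-form EKF update is \(\Sigma_{\text{post}} = (\Sigma_{\text{prior}}^{-1} + H^T R^{-1} H)^{-1}\); when the prior and measurement residuals are whitened by \(\Sigma_{\text{prior}}^{-1/2}\) and \(R^{-1/2}\), this is exactly the \((J^T J)^{-1}\) computed here.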
# Simulate EKF-style sequential updates.
# Prior: initial pose estimate with uncertainty.
prior_pose = jaxlie.SE2.from_xy_theta(0.5, 1.5, jnp.pi / 6)
prior_cov = jnp.diag(jnp.array([0.5**2, 0.5**2, 0.3**2])) # Prior uncertainty.
# For EKF equivalence, use Euclidean parameterization [x, y, theta].
class PoseVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
"""Pose as [x, y, theta] vector."""
@jaxls.Cost.factory
def euclidean_prior_cost(
vals: jaxls.VarValues,
pose: PoseVar,
prior_mean: jax.Array,
info_sqrt: jax.Array,
) -> jax.Array:
"""Prior cost in Euclidean coordinates."""
error = vals[pose] - prior_mean
return info_sqrt @ error
# Information matrix is inverse of covariance.
prior_info = jnp.linalg.inv(prior_cov)
prior_info_sqrt = jnp.linalg.cholesky(prior_info)
prior_mean = jnp.array([0.5, 1.5, jnp.pi / 6])
# Take a single measurement.
single_measurement_xy = measured_positions[0]
single_measurement_theta = measured_angles[0]
# Measurement noise covariance.
# Using higher noise than the data was generated with, to show the prior's influence.
meas_noise_std = 0.3
meas_info_sqrt = 1.0 / meas_noise_std # Weight = 1/sigma for residuals.
@jaxls.Cost.factory
def euclidean_position_cost(
vals: jaxls.VarValues,
pose: PoseVar,
measured_xy: jax.Array,
weight: float,
) -> jax.Array:
"""Weighted position residual (Euclidean)."""
return weight * (vals[pose][:2] - measured_xy)
@jaxls.Cost.factory
def euclidean_orientation_cost(
vals: jaxls.VarValues,
pose: PoseVar,
measured_theta: jax.Array,
weight: float,
) -> jax.Array:
"""Weighted orientation residual (Euclidean angle difference)."""
# Simple angle difference (works for small angles).
angle_diff = vals[pose][2] - measured_theta
# Wrap to [-pi, pi].
angle_diff = jnp.arctan2(jnp.sin(angle_diff), jnp.cos(angle_diff))
return weight * angle_diff
# EKF update = single Gauss-Newton step + covariance.
ekf_pose_var = PoseVar(id=0)
ekf_costs = [
euclidean_prior_cost(ekf_pose_var, prior_mean, prior_info_sqrt),
euclidean_position_cost(ekf_pose_var, single_measurement_xy, meas_info_sqrt),
euclidean_orientation_cost(ekf_pose_var, single_measurement_theta, meas_info_sqrt),
]
ekf_initial = jaxls.VarValues.make([ekf_pose_var.with_value(prior_mean)])
ekf_problem = jaxls.LeastSquaresProblem(ekf_costs, [ekf_pose_var]).analyze()
# Single iteration (like EKF).
ekf_solution = ekf_problem.solve(
ekf_initial,
trust_region=None, # Pure Gauss-Newton (no trust region).
termination=jaxls.TerminationConfig(max_iterations=1),
verbose=False,
)
# Get posterior covariance.
ekf_estimator = ekf_problem.make_covariance_estimator(
ekf_solution, scale_by_residual_variance=False
)
posterior_cov = ekf_estimator.covariance(ekf_pose_var)
ekf_pose_vec = ekf_solution[ekf_pose_var]
print("EKF-style update (1 Gauss-Newton iteration):")
print(f" Prior: x={prior_mean[0]:.3f}, y={prior_mean[1]:.3f}")
print(
f" Measurement: x={single_measurement_xy[0]:.3f}, y={single_measurement_xy[1]:.3f}"
)
print(f" Posterior: x={ekf_pose_vec[0]:.3f}, y={ekf_pose_vec[1]:.3f}")
print(
f"\nPrior uncertainty (std): x={jnp.sqrt(prior_cov[0, 0]):.3f}, y={jnp.sqrt(prior_cov[1, 1]):.3f}"
)
print(
f"Posterior uncertainty (std): x={jnp.sqrt(posterior_cov[0, 0]):.3f}, y={jnp.sqrt(posterior_cov[1, 1]):.3f}"
)
INFO | Building optimization problem with 3 terms and 1 variables: 3 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: euclidean_prior_cost
INFO | Vectorizing group with 1 costs, 1 variables each: euclidean_position_cost
INFO | Vectorizing group with 1 costs, 1 variables each: euclidean_orientation_cost
EKF-style update (1 Gauss-Newton iteration):
Prior: x=0.500, y=1.500
Measurement: x=1.050, y=1.986
Posterior: x=0.904, y=1.857
Prior uncertainty (std): x=0.500, y=0.500
Posterior uncertainty (std): x=0.257, y=0.257
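To verify the equivalence numerically, the closed-form information-form update can be computed directly. A sketch, reusing the variables above and noting that both costs observe the state directly, so the stacked measurement Jacobian is \(H = I\):

```python
# Closed-form EKF covariance update: (prior_info + H^T R^{-1} H)^{-1},
# with H = I and R = meas_noise_std**2 * I for the costs above.
H = jnp.eye(3)
R_inv = jnp.eye(3) / meas_noise_std**2
closed_form_cov = jnp.linalg.inv(prior_info + H.T @ R_inv @ H)
print("Closed-form posterior std:", jnp.sqrt(jnp.diag(closed_form_cov)))
print("jaxls posterior std:      ", jnp.sqrt(jnp.diag(posterior_cov)))
```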
Cross-covariance between variables#
For problems with multiple variables, we can compute cross-covariance blocks:
# Two-pose problem with relative constraint.
pose0 = jaxls.SE2Var(id=0)
pose1 = jaxls.SE2Var(id=1)
@jaxls.Cost.factory
def anchor_cost(
vals: jaxls.VarValues, pose: jaxls.SE2Var, target: jaxlie.SE2
) -> jax.Array:
"""Anchor pose to a target."""
return (vals[pose] @ target.inverse()).log()
@jaxls.Cost.factory
def relative_cost(
vals: jaxls.VarValues,
pose_a: jaxls.SE2Var,
pose_b: jaxls.SE2Var,
delta: jaxlie.SE2,
) -> jax.Array:
"""Relative pose constraint."""
T_a = vals[pose_a]
T_b = vals[pose_b]
measured = T_a.inverse() @ T_b
return (measured @ delta.inverse()).log()
two_pose_costs = [
anchor_cost(pose0, jaxlie.SE2.from_xy_theta(0.0, 0.0, 0.0)),
anchor_cost(pose1, jaxlie.SE2.from_xy_theta(2.0, 0.0, 0.0)),
relative_cost(pose0, pose1, jaxlie.SE2.from_xy_theta(1.0, 0.0, 0.0)),
]
two_pose_initial = jaxls.VarValues.make(
[
pose0.with_value(jaxlie.SE2.identity()),
pose1.with_value(jaxlie.SE2.from_xy_theta(1.5, 0.0, 0.0)),
]
)
two_pose_problem = jaxls.LeastSquaresProblem(two_pose_costs, [pose0, pose1]).analyze()
two_pose_solution = two_pose_problem.solve(two_pose_initial)
# Get covariance estimator.
two_pose_estimator = two_pose_problem.make_covariance_estimator(
two_pose_solution, scale_by_residual_variance=False
)
# Marginal covariances.
cov_00 = two_pose_estimator.covariance(pose0)
cov_11 = two_pose_estimator.covariance(pose1)
# Cross-covariance.
cov_01 = two_pose_estimator.covariance(pose0, pose1)
print("Marginal covariance pose0:")
print(cov_00)
print("\nMarginal covariance pose1:")
print(cov_11)
print("\nCross-covariance pose0-pose1:")
print(cov_01)
INFO | Building optimization problem with 3 terms and 2 variables: 3 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 1 costs, 2 variables each: relative_cost
INFO | Vectorizing group with 2 costs, 1 variables each: anchor_cost
INFO | step #0: cost=0.5000 lambd=0.0005 inexact_tol=1.0e-02
INFO | - relative_cost(1): 0.25000 (avg 0.08333)
INFO | - anchor_cost(2): 0.25000 (avg 0.04167)
INFO | accepted=True ATb_norm=7.07e-01 cost_prev=0.5000 cost_new=0.3333
INFO | step #1: cost=0.3333 lambd=0.0003 inexact_tol=1.0e-02
INFO | - relative_cost(1): 0.11112 (avg 0.03704)
INFO | - anchor_cost(2): 0.22221 (avg 0.03704)
INFO | accepted=True ATb_norm=1.34e-04 cost_prev=0.3333 cost_new=0.3333
INFO | Terminated @ iteration #2: cost=0.3333 criteria=[1 0 0], term_deltas=1.8e-07,1.2e-04,3.7e-05
Marginal covariance pose0:
[[0.6666667 0. 0. ]
[0. 0.7081737 0.11111112]
[0. 0.1111111 0.6666667 ]]
Marginal covariance pose1:
[[0.6666667 0. 0. ]
[0. 2.4351852 1.0555556 ]
[0. 1.0555556 0.63218397]]
Cross-covariance pose0-pose1:
[[0.33333334 0. 0. ]
[0. 0.6018518 0.15900384]
[0. 0.6111111 0.33333337]]
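If the full joint covariance over both poses is needed, it can be assembled from the blocks. A minimal sketch:

```python
# Assemble the 6x6 joint covariance from the marginal and cross blocks.
joint_cov = jnp.block(
    [
        [cov_00, cov_01],
        [cov_01.T, cov_11],
    ]
)
print("Joint covariance shape:", joint_cov.shape)
```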
Estimator options#
jaxls provides different methods for computing covariances:
Conjugate Gradient (default): GPU-friendly iterative solver. Fast when variables are weakly correlated.
Dense Cholesky: Direct solver using dense matrices. Suitable for small to medium problems.
CHOLMOD: Sparse direct solver (requires sksparse). Caches factorization for efficient repeated queries on large sparse problems.
# Default: Conjugate Gradient.
estimator_cg = problem.make_covariance_estimator(solution)
# Dense Cholesky.
estimator_dense = problem.make_covariance_estimator(
solution,
method=jaxls.LinearSolverCovarianceEstimatorConfig(linear_solver="dense_cholesky"),
)
# Compare results.
cov_cg = estimator_cg.covariance(pose_var)
cov_dense = estimator_dense.covariance(pose_var)
print("CG and Dense Cholesky give the same result:")
print(f" Max difference: {jnp.max(jnp.abs(cov_cg - cov_dense)):.2e}")
CG and Dense Cholesky give the same result:
Max difference: 2.33e-10
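For large sparse problems, CHOLMOD follows the same pattern. A sketch, assuming sksparse is installed and that the config accepts "cholmod" as the linear_solver value (check the LinearSolverCovarianceEstimatorConfig documentation for the exact option name):

```python
# Sparse direct solver via CHOLMOD. The factorization is cached, so repeated
# covariance() queries amortize the cost of the initial factorization.
estimator_cholmod = problem.make_covariance_estimator(
    solution,
    method=jaxls.LinearSolverCovarianceEstimatorConfig(linear_solver="cholmod"),
)
cov_cholmod = estimator_cholmod.covariance(pose_var)
```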
Summary#
Covariance estimation in jaxls:
After solving, create a covariance estimator with problem.make_covariance_estimator(solution).
Extract blocks using estimator.covariance(var) for marginal covariance or estimator.covariance(var0, var1) for cross-covariance.
EKF connection: The EKF measurement update equals one Gauss-Newton iteration plus \((J^T J)^{-1}\) covariance.
Multiple methods: CG (default, GPU-friendly), dense Cholesky, or CHOLMOD for large sparse problems.
For more details, see CovarianceEstimator, LinearSolverCovarianceEstimatorConfig, and make_covariance_estimator().