SE(2) pose graph#
In this notebook, we solve a 2D pose graph optimization problem: estimating robot poses from noisy relative measurements.
Inputs: Relative pose measurements (odometry + loop closures) from g2o file
Outputs: Globally consistent robot trajectory
Features used:

- SE2Var for SE(2) robot poses
- @jaxls.Cost.factory with batched edge construction
- Real g2o dataset with ~3500 poses and loop closures
import pathlib
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np
Loading the dataset#
Parse the g2o file to extract poses and edges. Edges include both odometry (consecutive poses) and loop closures (revisited locations). The g2o record formats are sketched below, followed by the parsing code:
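For reference, the two 2D g2o record types we parse look like the lines below; the values here are illustrative, not taken from the dataset. Each EDGE_SE2 line ends with the six upper-triangular entries of the 3x3 information (precision) matrix.

# Illustrative g2o records (not from the Manhattan dataset):
#   VERTEX_SE2 id x y theta
#   EDGE_SE2 i j dx dy dtheta q11 q12 q13 q22 q23 q33
example_vertex = "VERTEX_SE2 0 0.0 0.0 0.0"
example_edge = "EDGE_SE2 0 1 1.0 0.0 0.01 44.7 0.0 0.0 44.7 0.0 44.7"
parts = example_edge.split()
print(parts[3:6])  # relative measurement: dx, dy, dtheta
print(parts[6:])   # six upper-triangular precision entries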
@jax.jit
def parse_precision_matrix(components: jax.Array) -> jax.Array:
"""Convert upper triangular components to sqrt precision matrix.
Args:
components: Upper triangular components of the precision matrix (6,)
Returns:
Upper Cholesky factor of the precision matrix (3, 3)
"""
precision = jnp.zeros((3, 3))
triu_indices = jnp.triu_indices(3)
precision = precision.at[triu_indices].set(components)
precision = precision + precision.T - jnp.diag(jnp.diag(precision))
return jnp.linalg.cholesky(precision).T
def parse_g2o(path: pathlib.Path) -> dict:
"""Parse a 2D g2o file into poses and edges.
Args:
path: Path to the g2o file
Returns:
Dictionary with 'poses' (N, 3) array and 'edges' list of tuples
"""
with open(path) as f:
lines = f.readlines()
poses = [] # (x, y, theta)
edges = [] # (i, j, dx, dy, dtheta, precision_components)
for line in lines:
parts = line.strip().split()
if not parts:
continue
if parts[0] == "VERTEX_SE2":
_, idx, x, y, theta = parts
poses.append((float(x), float(y), float(theta)))
elif parts[0] == "EDGE_SE2":
_, i, j = parts[:3]
dx, dy, dtheta = map(float, parts[3:6])
precision_comps = list(map(float, parts[6:]))
edges.append((int(i), int(j), dx, dy, dtheta, np.array(precision_comps)))
return {"poses": np.array(poses), "edges": edges}
# Load the Manhattan 3500 dataset.
g2o_path = pathlib.Path("./data/input_M3500_g2o.g2o")
data = parse_g2o(g2o_path)
num_poses = len(data["poses"])
num_edges = len(data["edges"])
# Count odometry vs loop closure edges.
odometry_edges = [(i, j) for i, j, *_ in data["edges"] if j == i + 1]
loop_closure_edges = [(i, j) for i, j, *_ in data["edges"] if j != i + 1]
print(f"Poses: {num_poses}")
print(
f"Edges: {num_edges} ({len(odometry_edges)} odometry, {len(loop_closure_edges)} loop closures)"
)
Poses: 3500
Edges: 5453 (3499 odometry, 1954 loop closures)
Variables and costs#
Use SE2Var for poses on SE(2). The between factor measures the relative pose between two nodes, weighted by a precision matrix.
The sqrt_precision matrix is the upper Cholesky factor of the information (precision) matrix \(\Lambda = \Sigma^{-1}\), where \(\Sigma\) is the measurement covariance. By multiplying the residual by this factor, we get a whitened residual whose squared norm equals the Mahalanobis distance: \(r^T \Lambda \, r = \|L^T r\|^2\).
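As a quick sanity check on the whitening identity above, here is a minimal sketch with a made-up residual and information matrix (not values from the dataset), verifying that the squared norm of the whitened residual equals the Mahalanobis distance:

# Check r^T Λ r == ||L^T r||^2 for an illustrative 3x3 information matrix.
Lam = jnp.array([[4.0, 1.0, 0.0],
                 [1.0, 3.0, 0.5],
                 [0.0, 0.5, 2.0]])
r = jnp.array([0.1, -0.2, 0.05])
L = jnp.linalg.cholesky(Lam)  # lower factor, Λ = L @ L.T
sqrt_prec = L.T               # upper factor, same convention as parse_precision_matrix
print(r @ Lam @ r, jnp.sum((sqrt_prec @ r) ** 2))  # both give the same scalar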
# Create batched pose variables.
pose_vars = jaxls.SE2Var(id=jnp.arange(num_poses))
@jaxls.Cost.factory
def between_cost(
vals: jaxls.VarValues,
var_a: jaxls.SE2Var,
var_b: jaxls.SE2Var,
measured: jaxlie.SE2,
sqrt_precision: jax.Array,
) -> jax.Array:
"""Cost for relative pose measurement between two poses."""
T_a = vals[var_a]
T_b = vals[var_b]
# Error: measured^{-1} @ (T_a^{-1} @ T_b)
residual = (measured.inverse() @ (T_a.inverse() @ T_b)).log()
return sqrt_precision @ residual
@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_cost(
vals: jaxls.VarValues,
var: jaxls.SE2Var,
target: jaxlie.SE2,
) -> jax.Array:
"""Anchor the first pose to prevent gauge freedom."""
return (vals[var].inverse() @ target).log()
# Build edge arrays for batched cost construction.
edge_i = jnp.array([e[0] for e in data["edges"]])
edge_j = jnp.array([e[1] for e in data["edges"]])
# Measured relative poses.
measured_poses = jaxlie.SE2.from_xy_theta(
jnp.array([e[2] for e in data["edges"]]),
jnp.array([e[3] for e in data["edges"]]),
jnp.array([e[4] for e in data["edges"]]),
)
# Sqrt precision matrices.
precision_comps = jnp.array([e[5] for e in data["edges"]])
sqrt_precisions = jax.vmap(parse_precision_matrix)(precision_comps)
print(f"Batched edge arrays: {edge_i.shape[0]} edges")
Batched edge arrays: 5453 edges
Solving#
Problem structure#
The problem structure connects pose variables through between factors (odometry and loop closures). With 3500 poses and ~5500 edges, the visualization automatically limits nodes for performance:
# Create costs using batched construction.
costs: list[jaxls.Cost] = [
# All between factors in one batched call.
between_cost(
jaxls.SE2Var(id=edge_i),
jaxls.SE2Var(id=edge_j),
measured_poses,
sqrt_precisions,
),
# Anchor first pose.
anchor_cost(
jaxls.SE2Var(id=0),
jaxlie.SE2.from_xy_theta(
data["poses"][0, 0], data["poses"][0, 1], data["poses"][0, 2]
),
),
]
# Visualize the problem structure.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars])
problem.show()
# Initial values from g2o file (odometry integration).
initial_poses = jaxlie.SE2.from_xy_theta(
jnp.array(data["poses"][:, 0]),
jnp.array(data["poses"][:, 1]),
jnp.array(data["poses"][:, 2]),
)
initial_vals = jaxls.VarValues.make([pose_vars.with_value(initial_poses)])
# Analyze the problem.
problem = problem.analyze()
INFO | Building optimization problem with 5454 terms and 3500 variables: 5453 costs, 1 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 5453 costs, 2 variables each: between_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_cost
# Solve with Gauss-Newton (no trust region needed for this well-conditioned problem)
solution = problem.solve(initial_vals, trust_region=None)
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+07, constraint_dim=3
INFO | step #0: cost=2634707.7500 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 2634707.75000 (avg 161.05556)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.32e+04 cost_prev=2634707.7500 cost_new=57105.4102
INFO | step #1: cost=57105.3906 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 57105.39062 (avg 3.49076)
INFO | - augmented_anchor_cost(1): 0.01625 (avg 0.00542)
INFO | accepted=True ATb_norm=3.59e+03 cost_prev=57105.4102 cost_new=11499.3945
INFO | step #2: cost=11499.3926 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 11499.39258 (avg 0.70294)
INFO | - augmented_anchor_cost(1): 0.00040 (avg 0.00013)
INFO | accepted=True ATb_norm=1.24e+03 cost_prev=11499.3945 cost_new=3325.4399
INFO | step #3: cost=3325.4395 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 3325.43945 (avg 0.20328)
INFO | - augmented_anchor_cost(1): 0.00016 (avg 0.00005)
INFO | accepted=True ATb_norm=4.12e+02 cost_prev=3325.4399 cost_new=11917.8184
INFO | step #4: cost=11917.8184 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 11917.81836 (avg 0.72852)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=1.06e+03 cost_prev=11917.8184 cost_new=2345.9548
INFO | step #5: cost=2345.9546 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 2345.95459 (avg 0.14340)
INFO | - augmented_anchor_cost(1): 0.00001 (avg 0.00000)
INFO | accepted=True ATb_norm=4.09e+02 cost_prev=2345.9548 cost_new=705.6475
INFO | step #6: cost=705.6476 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 705.64758 (avg 0.04314)
INFO | - augmented_anchor_cost(1): 0.00002 (avg 0.00001)
INFO | accepted=True ATb_norm=1.58e+02 cost_prev=705.6475 cost_new=1310.0546
INFO | step #7: cost=1310.0547 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 1310.05469 (avg 0.08008)
INFO | - augmented_anchor_cost(1): 0.00001 (avg 0.00000)
INFO | accepted=True ATb_norm=2.80e+02 cost_prev=1310.0546 cost_new=181.4150
INFO | step #8: cost=181.4150 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(5453): 181.41498 (avg 0.01109)
INFO | - augmented_anchor_cost(1): 0.00001 (avg 0.00000)
INFO | accepted=True ATb_norm=2.18e+01 cost_prev=181.4150 cost_new=381.8472
INFO | step #9: cost=381.8471 lambd=0.0000 inexact_tol=5.5e-03
INFO | - between_cost(5453): 381.84714 (avg 0.02334)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=1.28e+02 cost_prev=381.8472 cost_new=139.9707
INFO | step #10: cost=139.9707 lambd=0.0000 inexact_tol=5.5e-03
INFO | - between_cost(5453): 139.97066 (avg 0.00856)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.30e+00 cost_prev=139.9707 cost_new=139.6545
INFO | step #11: cost=139.6545 lambd=0.0000 inexact_tol=6.0e-04
INFO | - between_cost(5453): 139.65454 (avg 0.00854)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=9.20e+00 cost_prev=139.6545 cost_new=137.9442
INFO | step #12: cost=137.9442 lambd=0.0000 inexact_tol=6.0e-04
INFO | - between_cost(5453): 137.94420 (avg 0.00843)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=7.06e-01 cost_prev=137.9442 cost_new=137.9200
INFO | step #13: cost=137.9200 lambd=0.0000 inexact_tol=6.0e-04
INFO | - between_cost(5453): 137.91998 (avg 0.00843)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=5.22e-01 cost_prev=137.9200 cost_new=137.9158
INFO | step #14: cost=137.9158 lambd=0.0000 inexact_tol=6.0e-04
INFO | - between_cost(5453): 137.91579 (avg 0.00843)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=5.64e-01 cost_prev=137.9158 cost_new=137.9130
INFO | step #15: cost=137.9130 lambd=0.0000 inexact_tol=6.0e-04
INFO | - between_cost(5453): 137.91296 (avg 0.00843)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=5.52e-01 cost_prev=137.9130 cost_new=137.9128
INFO | AL update: snorm=4.8128e-09, csupn=4.8128e-09, max_rho=1.0000e+07
INFO | Terminated @ iteration #16: cost=137.9128 criteria=[1 0 0], term_deltas=8.9e-07,1.3e-01,2.2e-05
Visualization#
Compare the initial odometry-only trajectory with the optimized result. Loop closures correct drift accumulated from odometry:
initial_xy = np.array(initial_vals[pose_vars].translation())
optimized_xy = np.array(solution[pose_vars].translation())
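A minimal plotting sketch for the comparison, assuming matplotlib is installed; it reuses initial_xy, optimized_xy, and loop_closure_edges from earlier cells, and the styling choices (colors, layout) are not part of jaxls:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharex=True, sharey=True)
axes[0].plot(initial_xy[:, 0], initial_xy[:, 1], lw=0.5, color="black")
axes[0].set_title("Initial (odometry only)")
axes[1].plot(optimized_xy[:, 0], optimized_xy[:, 1], lw=0.5, color="black")
axes[1].set_title("Optimized (with loop closures)")
# Draw loop-closure edges in red on the optimized trajectory.
for i, j in loop_closure_edges:
    axes[1].plot(optimized_xy[[i, j], 0], optimized_xy[[i, j], 1], lw=0.3, color="red")
for ax in axes:
    ax.set_aspect("equal")
plt.show()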
The optimization corrects drift accumulated from odometry-only integration. Loop closures (shown in red) constrain revisited locations to be consistent, resulting in a globally coherent map.
For more on Lie group variables, see jaxls.SE2Var and jaxls.SE3Var.