SE(3) pose graph
In this notebook, we solve a 3D pose graph optimization problem: estimating robot poses from noisy relative measurements in full 6-DOF.
Extending pose graph optimization to 3D enables SLAM for drones, underwater vehicles, and handheld mapping devices. The core idea remains the same as 2D: relative motion measurements between poses accumulate error, but loop closures that recognize revisited locations provide global constraints.
This example uses the sphere2500 dataset, where poses are arranged on a sphere surface.
Features used:
- `SE3Var` for SE(3) robot poses
- `@jaxls.Cost.factory` with batched edge construction
- g2o dataset with ~2500 poses and loop closures
import pathlib
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np
Loading the dataset
Parse the g2o file to extract poses and edges. The sphere2500 dataset has poses arranged on a sphere surface with loop closures connecting nearby poses:
@jax.jit
def parse_precision_matrix(components: jax.Array) -> jax.Array:
    """Build an upper-triangular square-root precision matrix from g2o data.

    The 21 entries are the upper triangle (diagonal included) of a symmetric
    6x6 precision/information matrix, listed row by row. The matrix is
    rebuilt, then factored as ``P = L L^T``; the returned ``R = L^T`` is upper
    triangular and satisfies ``R^T R = P``, so ``||R r||^2 == r^T P r``.

    Args:
        components: Upper triangular components of the precision matrix (21,)

    Returns:
        Upper Cholesky factor of the precision matrix (6, 6)
    """
    rows, cols = jnp.triu_indices(6)
    upper = jnp.zeros((6, 6)).at[rows, cols].set(components)
    # Mirror only the strictly-upper part so the diagonal is not doubled.
    symmetric = upper + jnp.triu(upper, k=1).T
    lower_factor = jnp.linalg.cholesky(symmetric)
    return lower_factor.T
def parse_g2o_se3(path: pathlib.Path) -> dict:
    """Parse a 3D g2o file (VERTEX_SE3:QUAT and EDGE_SE3:QUAT records).

    Poses are returned ordered by vertex id, so row ``k`` of ``poses`` is the
    pose of vertex ``k``. This lets edge endpoint ids be used directly as row
    indices and variable ids. Blank lines and other record types are ignored.

    Args:
        path: Path to the g2o file

    Returns:
        Dictionary with:
          - ``"poses"``: float array of shape (N, 7), rows ``(x, y, z, qx, qy, qz, qw)``.
          - ``"edges"``: list of tuples
            ``(i, j, x, y, z, qx, qy, qz, qw, precision_components)``, where
            ``precision_components`` is a (21,) array holding the row-major
            upper triangle of the 6x6 information matrix.

    Raises:
        ValueError: If a record has the wrong number of fields, a vertex id is
            duplicated, vertex ids are not exactly ``0..N-1``, or an edge does
            not carry 21 information-matrix entries.
    """
    poses_by_id: dict[int, tuple[float, ...]] = {}
    edges = []  # (i, j, x, y, z, qx, qy, qz, qw, precision_components)

    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            if parts[0] == "VERTEX_SE3:QUAT":
                # Format: VERTEX_SE3:QUAT id x y z qx qy qz qw
                _, idx, x, y, z, qx, qy, qz, qw = parts
                vertex_id = int(idx)
                if vertex_id in poses_by_id:
                    raise ValueError(f"Duplicate vertex id {vertex_id} in {path}")
                poses_by_id[vertex_id] = tuple(
                    float(v) for v in (x, y, z, qx, qy, qz, qw)
                )
            elif parts[0] == "EDGE_SE3:QUAT":
                # Format: EDGE_SE3:QUAT i j x y z qx qy qz qw info_upper_tri(21)
                _, i, j = parts[:3]
                numerical = [float(v) for v in parts[3:]]
                if len(numerical) != 7 + 21:
                    raise ValueError(
                        f"EDGE_SE3:QUAT {i} {j} has {len(numerical)} numeric "
                        f"fields, expected 28 (7 pose + 21 information)"
                    )
                ex, ey, ez = numerical[0:3]
                eqx, eqy, eqz, eqw = numerical[3:7]
                precision_comps = np.array(numerical[7:])
                edges.append(
                    (int(i), int(j), ex, ey, ez, eqx, eqy, eqz, eqw, precision_comps)
                )

    num_vertices = len(poses_by_id)
    # Edge ids are later used as array rows / variable ids, so vertices must
    # form the contiguous range 0..N-1; otherwise poses and edges misalign.
    if sorted(poses_by_id) != list(range(num_vertices)):
        raise ValueError(f"Vertex ids in {path} are not contiguous from 0")
    # reshape keeps the (0, 7) shape for files without vertices.
    poses = np.array(
        [poses_by_id[k] for k in range(num_vertices)], dtype=float
    ).reshape(num_vertices, 7)
    return {"poses": poses, "edges": edges}
# Load the sphere2500 dataset.
g2o_path = pathlib.Path("data") / "sphere2500.g2o"
data = parse_g2o_se3(g2o_path)
num_poses, num_edges = len(data["poses"]), len(data["edges"])

# An edge linking consecutive ids (j == i + 1) is odometry; every other
# edge is a loop closure.
odometry_edges = [(a, b) for a, b, *_ in data["edges"] if b - a == 1]
loop_closure_edges = [(a, b) for a, b, *_ in data["edges"] if b - a != 1]

print(f"Poses: {num_poses}")
print(
    f"Edges: {num_edges} ({len(odometry_edges)} odometry, {len(loop_closure_edges)} loop closures)"
)
Poses: 2500
Edges: 4949 (2499 odometry, 2450 loop closures)
Variables and costs
Use SE3Var for poses on SE(3). Create batched costs for efficient optimization:
# A single batched SE(3) variable covering pose ids 0 .. num_poses - 1.
pose_vars = jaxls.SE3Var(id=jnp.arange(0, num_poses))
@jaxls.Cost.factory
def between_cost(
    vals: jaxls.VarValues,
    var_a: jaxls.SE3Var,
    var_b: jaxls.SE3Var,
    measured: jaxlie.SE3,
    sqrt_precision: jax.Array,
) -> jax.Array:
    """Whitened residual of a relative-pose measurement between two poses.

    The predicted relative transform ``T_a^{-1} T_b`` is compared with the
    measurement via ``measured^{-1} (T_a^{-1} T_b)``. That discrepancy is
    mapped to the tangent space with ``log()`` and scaled by the
    square-root precision matrix.
    """
    predicted = vals[var_a].inverse() @ vals[var_b]
    discrepancy = measured.inverse() @ predicted
    return sqrt_precision @ discrepancy.log()
@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jaxlie.SE3,
) -> jax.Array:
    """Equality constraint holding ``var`` at the pose ``target``.

    Applied to the first pose to remove the pose graph's gauge freedom: the
    tangent-space error ``log(T^{-1} target)`` is constrained to zero.
    """
    current_pose = vals[var]
    return (current_pose.inverse() @ target).log()
# Endpoint ids of every edge; used as the ids of the batched SE3Vars.
edge_i = jnp.array([i for i, *_ in data["edges"]])
edge_j = jnp.array([j for _, j, *_ in data["edges"]])

# Relative-pose measurements. Each edge tuple is
# (i, j, x, y, z, qx, qy, qz, qw, precision_components), and g2o writes the
# quaternion in xyzw order, hence from_quaternion_xyzw.
measured_poses = jaxlie.SE3.from_rotation_and_translation(
    translation=jnp.array([e[2:5] for e in data["edges"]]),
    rotation=jaxlie.SO3.from_quaternion_xyzw(
        jnp.array([e[5:9] for e in data["edges"]])
    ),
)

# Square-root information matrices, one upper-triangular (6, 6) factor per edge.
# NOTE(review): g2o defines the EDGE_SE3:QUAT information matrix over a
# (translation, quaternion-vector) error, while between_cost whitens the
# SE3.log() tangent; the rotational weighting may differ by a constant
# factor. Confirm this is the intended convention.
precision_comps = jnp.array([e[-1] for e in data["edges"]])
sqrt_precisions = jax.vmap(parse_precision_matrix)(precision_comps)
print(f"Batched edge arrays: {edge_i.shape[0]} edges")
Batched edge arrays: 4949 edges
Solving
# Initial guess straight from the g2o vertices. Each row is
# (x, y, z, qx, qy, qz, qw), so the quaternion is read in xyzw order.
initial_poses = jaxlie.SE3.from_rotation_and_translation(
    translation=jnp.array(data["poses"][:, 0:3]),
    rotation=jaxlie.SO3.from_quaternion_xyzw(jnp.array(data["poses"][:, 3:7])),
)

costs: list[jaxls.Cost] = []
# Every between factor, created in one batched call over all edges.
costs.append(
    between_cost(
        jaxls.SE3Var(id=edge_i),
        jaxls.SE3Var(id=edge_j),
        measured_poses,
        sqrt_precisions,
    )
)
# Equality constraint pinning pose 0 to its initial value (gauge fixing).
costs.append(
    anchor_cost(
        jaxls.SE3Var(id=0),
        jaxlie.SE3(wxyz_xyz=initial_poses.wxyz_xyz[0]),
    )
)

initial_vals = jaxls.VarValues.make([pose_vars.with_value(initial_poses)])
# Build the least-squares problem from the cost list and the batched pose variable.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars])
# Visualize the problem structure.
problem.show()
# Analyze the problem. The returned object replaces `problem` and is the one
# that gets solved; per the log output, this groups costs into vectorized batches.
problem = problem.analyze()
INFO | Building optimization problem with 4950 terms and 2500 variables: 4949 costs, 1 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 4949 costs, 2 variables each: between_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_cost
# Solve with Gauss-Newton: trust_region=None requests no damping (the log
# reports lambd=0 at every step). The anchor equality constraint is handled
# by jaxls's augmented-Lagrangian updates (see "Augmented Lagrangian" in the log).
solution = problem.solve(initial_vals, trust_region=None)
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+07, constraint_dim=6
INFO | step #0: cost=2611316.0000 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(4949): 2611316.00000 (avg 87.94086)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=1.65e+04 cost_prev=2611315.2500 cost_new=327240.5000
INFO | step #1: cost=327240.4688 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(4949): 327240.46875 (avg 11.02042)
INFO | - augmented_anchor_cost(1): 0.00734 (avg 0.00122)
INFO | accepted=True ATb_norm=5.87e+03 cost_prev=327240.5000 cost_new=13114.3945
INFO | step #2: cost=13114.3926 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(4949): 13114.39258 (avg 0.44165)
INFO | - augmented_anchor_cost(1): 0.00139 (avg 0.00023)
INFO | accepted=True ATb_norm=1.06e+03 cost_prev=13114.3945 cost_new=1489.4092
INFO | step #3: cost=1489.4070 lambd=0.0000 inexact_tol=1.0e-02
INFO | - between_cost(4949): 1489.40698 (avg 0.05016)
INFO | - augmented_anchor_cost(1): 0.00223 (avg 0.00037)
INFO | accepted=True ATb_norm=8.04e+01 cost_prev=1489.4092 cost_new=1376.8256
INFO | step #4: cost=1376.8253 lambd=0.0000 inexact_tol=5.2e-03
INFO | - between_cost(4949): 1376.82532 (avg 0.04637)
INFO | - augmented_anchor_cost(1): 0.00012 (avg 0.00002)
INFO | accepted=True ATb_norm=4.46e+01 cost_prev=1376.8256 cost_new=1351.5841
INFO | step #5: cost=1351.5840 lambd=0.0000 inexact_tol=5.2e-03
INFO | - between_cost(4949): 1351.58398 (avg 0.04552)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=2.23e+00 cost_prev=1351.5841 cost_new=1351.4075
INFO | step #6: cost=1351.4073 lambd=0.0000 inexact_tol=2.2e-03
INFO | - between_cost(4949): 1351.40735 (avg 0.04551)
INFO | - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=1.08e-01 cost_prev=1351.4075 cost_new=1351.4023
INFO | AL update: snorm=2.1340e-08, csupn=2.1340e-08, max_rho=1.0000e+07
INFO | Terminated @ iteration #7: cost=1351.4023 criteria=[1 0 0], term_deltas=3.7e-06,9.3e-03,1.5e-04
Visualization
Compare the initial trajectory with the optimized result. The poses form a sphere surface with loop closures connecting nearby points:
# Translation of every pose before and after optimization, as NumPy arrays for plotting.
initial_xyz, optimized_xyz = (
    np.array(vals[pose_vars].translation()) for vals in (initial_vals, solution)
)
viser (listening on *:8082) — HTTP: http://localhost:8082 · WebSocket: ws://localhost:8082
The optimization refines the noisy initial pose estimates using loop closure constraints. The resulting trajectory forms a clean sphere surface, demonstrating SE(3) pose graph optimization in 3D.
For more on Lie group variables, see jaxls.SE3Var and jaxls.SE2Var.