SE(3) pose graph

In this notebook, we solve a 3D pose graph optimization problem: estimating robot poses from noisy relative measurements in full 6-DOF.

Extending pose graph optimization to 3D enables SLAM for drones, underwater vehicles, and handheld mapping devices. The core idea remains the same as 2D: relative motion measurements between poses accumulate error, but loop closures that recognize revisited locations provide global constraints.

This example uses the sphere2500 dataset, in which the poses lie on the surface of a sphere.

Features used:

  • SE3Var for SE(3) robot poses

  • @jaxls.Cost.factory with batched edge construction

  • g2o dataset with ~2500 poses and loop closures

Hide code cell source

import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import pathlib

import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np

Loading the dataset

Parse the g2o file to extract poses and edges. The sphere2500 dataset has poses arranged on a sphere surface with loop closures connecting nearby poses:

@jax.jit
def parse_precision_matrix(components: jax.Array) -> jax.Array:
    """Build the upper-triangular square root of a 6x6 precision matrix.

    The 21 entries are the row-major upper triangle of a symmetric precision
    matrix ``P`` (the ordering produced by ``jnp.triu_indices(6)``).

    Args:
        components: Upper-triangular entries of ``P``, shape (21,).

    Returns:
        Upper-triangular ``U`` of shape (6, 6) satisfying ``U.T @ U == P``.
    """
    upper = jnp.zeros((6, 6)).at[jnp.triu_indices(6)].set(components)
    # Mirror the strictly-upper part below the diagonal to recover symmetric P.
    symmetric = upper + jnp.triu(upper, k=1).T
    # cholesky returns lower L with P = L @ L.T, so L.T is the upper factor.
    lower_factor = jnp.linalg.cholesky(symmetric)
    return lower_factor.T


def parse_g2o_se3(path: pathlib.Path) -> dict:
    """Parse a 3D g2o file (VERTEX_SE3:QUAT and EDGE_SE3:QUAT records).

    Blank lines and lines with any other tag are ignored.

    Args:
        path: Path to the g2o file.

    Returns:
        Dictionary with
          - ``'poses'``: float array of shape (N, 7) with rows
            ``(x, y, z, qx, qy, qz, qw)``; row ``k`` is the vertex whose id is
            ``k`` (vertices may appear in any order in the file).
          - ``'edges'``: list of tuples
            ``(i, j, x, y, z, qx, qy, qz, qw, precision_components)`` where
            ``precision_components`` is a (21,) array holding the upper
            triangle of the edge's information matrix.

    Raises:
        ValueError: On a malformed vertex/edge line, a duplicate vertex id,
            vertex ids that are not exactly ``0..N-1`` (edges index poses by
            id, so a gap would silently misalign them), or an edge that does
            not carry 21 information-matrix entries.
    """
    vertices: dict[int, tuple[float, ...]] = {}
    edges = []  # (i, j, x, y, z, qx, qy, qz, qw, precision_components)

    with open(path, encoding="utf-8") as f:
        # Stream line by line instead of loading the whole file.
        for line_no, line in enumerate(f, start=1):
            parts = line.split()
            if not parts:
                continue

            tag = parts[0]
            if tag == "VERTEX_SE3:QUAT":
                # Format: VERTEX_SE3:QUAT id x y z qx qy qz qw
                if len(parts) != 9:
                    raise ValueError(
                        f"{path}:{line_no}: VERTEX_SE3:QUAT expects 9 fields, "
                        f"got {len(parts)}"
                    )
                vertex_id = int(parts[1])
                if vertex_id in vertices:
                    raise ValueError(
                        f"{path}:{line_no}: duplicate vertex id {vertex_id}"
                    )
                vertices[vertex_id] = tuple(float(v) for v in parts[2:9])

            elif tag == "EDGE_SE3:QUAT":
                # Format: EDGE_SE3:QUAT i j x y z qx qy qz qw info_upper_tri(21)
                numerical = [float(v) for v in parts[3:]]
                if len(numerical) != 7 + 21:
                    raise ValueError(
                        f"{path}:{line_no}: EDGE_SE3:QUAT expects 7 pose values "
                        f"and 21 information entries, got {len(numerical)} numbers"
                    )
                x, y, z, qx, qy, qz, qw = numerical[:7]
                precision_comps = np.array(numerical[7:])
                edges.append(
                    (int(parts[1]), int(parts[2]), x, y, z, qx, qy, qz, qw, precision_comps)
                )

    # Edges refer to vertices by id, so the pose array must be indexed by id.
    ids = sorted(vertices)
    if ids != list(range(len(ids))):
        raise ValueError(
            f"{path}: vertex ids must be exactly 0..{len(ids) - 1}, "
            f"got ids in range [{ids[0]}, {ids[-1]}]"
        )
    # reshape keeps the (N, 7) shape even when the file has no vertices.
    poses = np.array([vertices[k] for k in ids], dtype=float).reshape(-1, 7)
    return {"poses": poses, "edges": edges}
# Load the sphere2500 dataset and summarize the graph.
g2o_path = pathlib.Path("./data/sphere2500.g2o")
data = parse_g2o_se3(g2o_path)

num_poses, num_edges = len(data["poses"]), len(data["edges"])

# Edges between consecutive pose ids (j == i + 1) are treated as odometry;
# every other edge is treated as a loop closure.
edge_pairs = [(i, j) for i, j, *_ in data["edges"]]
odometry_edges = [pair for pair in edge_pairs if pair[1] == pair[0] + 1]
loop_closure_edges = [pair for pair in edge_pairs if pair[1] != pair[0] + 1]

print(f"Poses: {num_poses}")
print(
    f"Edges: {num_edges} ({len(odometry_edges)} odometry, {len(loop_closure_edges)} loop closures)"
)
Poses: 2500
Edges: 4949 (2499 odometry, 2450 loop closures)

Variables and costs

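Use SE3Var for poses on SE(3). Every edge becomes a between cost that compares the measured relative pose with the current estimate. The first pose is pinned with an equality constraint to remove the global gauge freedom. All costs are created in batched form for efficient optimization: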

# A single batched SE3Var covering every pose id in [0, num_poses).
pose_vars = jaxls.SE3Var(id=jnp.arange(0, num_poses))


@jaxls.Cost.factory
def between_cost(
    vals: jaxls.VarValues,
    var_a: jaxls.SE3Var,
    var_b: jaxls.SE3Var,
    measured: jaxlie.SE3,
    sqrt_precision: jax.Array,
) -> jax.Array:
    """Whitened residual of a relative-pose measurement from pose a to pose b.

    The raw residual is log(measured^{-1} @ (T_a^{-1} @ T_b)): the tangent-space
    discrepancy between the predicted and the measured relative transform. It
    is then multiplied by the square-root precision matrix.
    """
    predicted_relative = vals[var_a].inverse() @ vals[var_b]
    discrepancy = measured.inverse() @ predicted_relative
    return sqrt_precision @ discrepancy.log()


@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jaxlie.SE3,
) -> jax.Array:
    """Equality constraint pinning a pose to `target` to remove gauge freedom.

    Returns log(T^{-1} @ target), which vanishes when the pose equals `target`.
    """
    current_pose = vals[var]
    deviation = current_pose.inverse() @ target
    return deviation.log()
# Split the parsed edge tuples into per-field columns for batched costs.
# Tuple layout: (i, j, x, y, z, qx, qy, qz, qw, precision_components).
edge_records = data["edges"]
edge_i = jnp.array([rec[0] for rec in edge_records])
edge_j = jnp.array([rec[1] for rec in edge_records])

# Measured relative poses; g2o stores quaternions in xyzw order.
edge_quat_xyzw = jnp.array([rec[5:9] for rec in edge_records])
edge_xyz = jnp.array([rec[2:5] for rec in edge_records])
measured_poses = jaxlie.SE3.from_rotation_and_translation(
    rotation=jaxlie.SO3.from_quaternion_xyzw(edge_quat_xyzw),
    translation=edge_xyz,
)

# Upper-triangular information entries -> per-edge sqrt precision matrices.
precision_comps = jnp.array([rec[9] for rec in edge_records])
sqrt_precisions = jax.vmap(parse_precision_matrix)(precision_comps)

print(f"Batched edge arrays: {edge_i.shape[0]} edges")
Batched edge arrays: 4949 edges

Solving

# Initial guess taken directly from the g2o vertices (quaternions stored xyzw).
initial_rotations = jaxlie.SO3.from_quaternion_xyzw(
    jnp.array(data["poses"][:, 3:7])  # qx, qy, qz, qw
)
initial_translations = jnp.array(data["poses"][:, 0:3])  # x, y, z
initial_poses = jaxlie.SE3.from_rotation_and_translation(
    rotation=initial_rotations,
    translation=initial_translations,
)

# Every relative-pose measurement, added as one batched cost.
relative_pose_costs = between_cost(
    jaxls.SE3Var(id=edge_i),
    jaxls.SE3Var(id=edge_j),
    measured_poses,
    sqrt_precisions,
)
# Pin pose 0 to its initial value to fix the global gauge.
first_pose_anchor = anchor_cost(
    jaxls.SE3Var(id=0),
    jaxlie.SE3(wxyz_xyz=initial_poses.wxyz_xyz[0]),
)
costs: list[jaxls.Cost] = [relative_pose_costs, first_pose_anchor]

initial_vals = jaxls.VarValues.make([pose_vars.with_value(initial_poses)])

# Assemble the least-squares problem over all pose variables.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars])

# Visualize the problem structure.
problem.show()
# Analyze the problem (groups and vectorizes the costs before solving).
problem = problem.analyze()
INFO     | Building optimization problem with 4950 terms and 2500 variables: 4949 costs, 1 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 4949 costs, 2 variables each: between_cost
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_cost
# Solve with Gauss-Newton: trust_region=None requests undamped steps (the log
# reports lambd=0 throughout). Because anchor_cost is an equality constraint,
# the solver runs these iterations inside an augmented Lagrangian loop.
solution = problem.solve(initial_vals, trust_region=None)
INFO     | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+07, constraint_dim=6
INFO     |  step #0: cost=2611316.0000 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(4949): 2611316.00000 (avg 87.94086)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.65e+04 cost_prev=2611315.2500 cost_new=327240.5000
INFO     |  step #1: cost=327240.4688 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(4949): 327240.46875 (avg 11.02042)
INFO     |      - augmented_anchor_cost(1): 0.00734 (avg 0.00122)
INFO     |      accepted=True ATb_norm=5.87e+03 cost_prev=327240.5000 cost_new=13114.3945
INFO     |  step #2: cost=13114.3926 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(4949): 13114.39258 (avg 0.44165)
INFO     |      - augmented_anchor_cost(1): 0.00139 (avg 0.00023)
INFO     |      accepted=True ATb_norm=1.06e+03 cost_prev=13114.3945 cost_new=1489.4092
INFO     |  step #3: cost=1489.4070 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(4949): 1489.40698 (avg 0.05016)
INFO     |      - augmented_anchor_cost(1): 0.00223 (avg 0.00037)
INFO     |      accepted=True ATb_norm=8.04e+01 cost_prev=1489.4092 cost_new=1376.8256
INFO     |  step #4: cost=1376.8253 lambd=0.0000 inexact_tol=5.2e-03
INFO     |      - between_cost(4949): 1376.82532 (avg 0.04637)
INFO     |      - augmented_anchor_cost(1): 0.00012 (avg 0.00002)
INFO     |      accepted=True ATb_norm=4.46e+01 cost_prev=1376.8256 cost_new=1351.5841
INFO     |  step #5: cost=1351.5840 lambd=0.0000 inexact_tol=5.2e-03
INFO     |      - between_cost(4949): 1351.58398 (avg 0.04552)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.23e+00 cost_prev=1351.5841 cost_new=1351.4075
INFO     |  step #6: cost=1351.4073 lambd=0.0000 inexact_tol=2.2e-03
INFO     |      - between_cost(4949): 1351.40735 (avg 0.04551)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.08e-01 cost_prev=1351.4075 cost_new=1351.4023
INFO     |  AL update: snorm=2.1340e-08, csupn=2.1340e-08, max_rho=1.0000e+07
INFO     | Terminated @ iteration #7: cost=1351.4023 criteria=[1 0 0], term_deltas=3.7e-06,9.3e-03,1.5e-04

Visualization

Compare the initial trajectory (left) with the optimized result (right). The poses form a sphere surface. Loop closures connect nearby points; only every 10th loop closure is drawn to keep the scene responsive:

# Pose positions before and after optimization, as NumPy arrays for plotting.
initial_xyz, optimized_xyz = (
    np.array(values[pose_vars].translation()) for values in (initial_vals, solution)
)

Hide code cell source

import contextlib
import io
import viser

# Create Viser server (suppress output).
with (
    contextlib.redirect_stdout(io.StringIO()),
    contextlib.redirect_stderr(io.StringIO()),
):
    server = viser.ViserServer(verbose=False)

# Offset for side-by-side views: initial on the left (-x), optimized on the right (+x).
offset = 80.0

# RGB colors shared by point clouds and line segments.
STEEL_BLUE = (70, 130, 180)
FOREST_GREEN = (34, 139, 34)
TOMATO = (255, 99, 71)


def _shifted_segments(xyz: np.ndarray, edges, dx: float) -> np.ndarray:
    """Return (len(edges), 2, 3) segment endpoints for index pairs `edges`, shifted by `dx` in x."""
    # reshape keeps a valid (0, 2) index array when `edges` is empty.
    index_pairs = np.asarray(edges, dtype=int).reshape(-1, 2)
    return xyz[index_pairs] + np.array([dx, 0.0, 0.0])


def _add_segments(name: str, segments: np.ndarray, rgb) -> None:
    """Add `segments` to the scene under `name`, drawn in a single color."""
    server.scene.add_line_segments(
        name,
        points=segments,
        colors=np.full((len(segments), 2, 3), rgb, dtype=np.uint8),
        line_width=1.0,
    )


# Compute centers for label placement.
initial_center = initial_xyz.mean(axis=0)
optimized_center = optimized_xyz.mean(axis=0)
label_height = max(initial_xyz[:, 2].max(), optimized_xyz[:, 2].max()) + 15

# Set initial camera position for a good view of both trajectories.
server.initial_camera.position = (0.0, -250.0, initial_center[2])
server.initial_camera.look_at = (0.0, 0.0, initial_center[2])

# Add labels above each sphere.
server.scene.add_label(
    "/initial/label",
    "Initial",
    position=(-offset + initial_center[0], initial_center[1], label_height),
    anchor="bottom-center",
)
server.scene.add_label(
    "/optimized/label",
    "Optimized",
    position=(offset + optimized_center[0], optimized_center[1], label_height),
    anchor="bottom-center",
)

# Initial trajectory (left side): poses, odometry chain, loop closures.
server.scene.add_point_cloud(
    "/initial/poses",
    points=initial_xyz + np.array([-offset, 0, 0]),
    colors=np.full((len(initial_xyz), 3), STEEL_BLUE, dtype=np.uint8),
    point_size=0.15,
)
odom_segments_init = _shifted_segments(initial_xyz, odometry_edges, -offset)
_add_segments("/initial/odometry", odom_segments_init, STEEL_BLUE)

# Loop closures are subsampled (every 10th) for rendering performance.
lc_edges = loop_closure_edges[::10]
lc_segments_init = _shifted_segments(initial_xyz, lc_edges, -offset)
_add_segments("/initial/loop_closures", lc_segments_init, TOMATO)

# Optimized trajectory (right side): same layers, shifted the other way.
server.scene.add_point_cloud(
    "/optimized/poses",
    points=optimized_xyz + np.array([offset, 0, 0]),
    colors=np.full((len(optimized_xyz), 3), FOREST_GREEN, dtype=np.uint8),
    point_size=0.15,
)
odom_segments_opt = _shifted_segments(optimized_xyz, odometry_edges, offset)
_add_segments("/optimized/odometry", odom_segments_opt, FOREST_GREEN)

lc_segments_opt = _shifted_segments(optimized_xyz, lc_edges, offset)
_add_segments("/optimized/loop_closures", lc_segments_opt, TOMATO)

# Display inline in the notebook.
server.scene.show(height=500)
╭────── viser (listening *:8082) ───────╮
│             ╷                         │
│   HTTP      │ http://localhost:8082   │
│   Websocket │ ws://localhost:8082     │
│             ╵                         │
╰───────────────────────────────────────╯

The optimization refines the noisy initial pose estimates using loop closure constraints. The total cost drops from about 2.6 × 10^6 to about 1351 in seven Gauss-Newton iterations. The resulting trajectory forms a clean sphere surface, demonstrating SE(3) pose graph optimization in 3D.

For more on Lie group variables, see jaxls.SE3Var and jaxls.SE2Var.