SE(2) pose graph

In this notebook, we solve a 2D pose graph optimization problem: estimating robot poses from noisy relative measurements.

Inputs: Relative pose measurements (odometry + loop closures) from a g2o file
Outputs: Globally consistent robot trajectory

Features used:

  • SE2Var for SE(2) robot poses

  • @jaxls.Cost.factory with batched edge construction

  • Real g2o dataset with ~3500 poses and loop closures

Hide code cell source

import sys
from loguru import logger

# Drop the existing loguru handlers and log to stdout instead, using a compact
# "LEVEL | message" format (the format of the INFO lines the solver prints below).
logger.remove()
# The trailing ';' is a notebook idiom that suppresses echoing the call's return value.
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import pathlib

import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np

Loading the dataset

Parse the g2o file to extract poses and edges. Edges include both odometry (consecutive poses) and loop closures (revisited locations):

@jax.jit
def parse_precision_matrix(components: jax.Array) -> jax.Array:
    """Turn g2o upper-triangular information entries into a square-root factor.

    The six entries are the row-major upper triangle of the symmetric 3x3
    information matrix: (I11, I12, I13, I22, I23, I33).

    Args:
        components: The six upper-triangular entries, shape (6,).

    Returns:
        Upper-triangular R of shape (3, 3) with R.T @ R equal to the
        information matrix, i.e. the transpose of its lower Cholesky factor.
    """
    rows, cols = jnp.triu_indices(3)
    upper = jnp.zeros((3, 3)).at[rows, cols].set(components)
    # Mirror only the strict upper triangle into the lower half so the
    # diagonal is counted exactly once.
    symmetric = upper + jnp.triu(upper, k=1).T
    lower_factor = jnp.linalg.cholesky(symmetric)
    return lower_factor.T


def parse_g2o(path: pathlib.Path) -> dict:
    """Parse a 2D g2o file into poses and edges.

    Only ``VERTEX_SE2`` and ``EDGE_SE2`` records are read; blank lines and
    any other record types are skipped.

    Args:
        path: Path to the g2o file.

    Returns:
        Dictionary with:
          - ``"poses"``: float array of shape (N, 3) holding ``(x, y, theta)``.
            Row ``k`` is vertex id ``k``, independent of the order in which
            vertices appear in the file.
          - ``"edges"``: list of ``(i, j, dx, dy, dtheta, precision_components)``
            tuples, where ``precision_components`` is a NumPy array of the
            upper-triangular information-matrix entries as listed in the file.

    Raises:
        ValueError: If a vertex id is repeated, or if the vertex ids are not
            exactly ``0..N-1``. Edges refer to poses by id, and the returned
            array is indexed by position, so any other layout would silently
            attach edges to the wrong poses.
    """
    poses_by_id: dict[int, tuple[float, float, float]] = {}
    edges = []  # (i, j, dx, dy, dtheta, precision_components)

    # Stream the file line by line instead of materializing it with readlines().
    with open(path, encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue

            tag = parts[0]
            if tag == "VERTEX_SE2":
                _, idx, x, y, theta = parts
                vertex_id = int(idx)
                if vertex_id in poses_by_id:
                    raise ValueError(f"Duplicate VERTEX_SE2 id {vertex_id} in {path}")
                poses_by_id[vertex_id] = (float(x), float(y), float(theta))

            elif tag == "EDGE_SE2":
                i, j = int(parts[1]), int(parts[2])
                dx, dy, dtheta = map(float, parts[3:6])
                precision_comps = np.array([float(v) for v in parts[6:]])
                edges.append((i, j, dx, dy, dtheta, precision_comps))

    num_poses = len(poses_by_id)
    if sorted(poses_by_id) != list(range(num_poses)):
        raise ValueError(
            f"VERTEX_SE2 ids in {path} must be exactly 0..{num_poses - 1}; "
            "edges index poses by id"
        )

    # Order rows by vertex id; the reshape keeps an empty file at shape (0, 3).
    poses = np.array(
        [poses_by_id[k] for k in range(num_poses)], dtype=float
    ).reshape(num_poses, 3)
    return {"poses": poses, "edges": edges}
# Load the Manhattan 3500 (M3500) pose-graph dataset.
g2o_path = pathlib.Path("./data/input_M3500_g2o.g2o")
data = parse_g2o(g2o_path)

num_poses, num_edges = len(data["poses"]), len(data["edges"])

# Edges between consecutive pose ids (j == i + 1) are treated as odometry;
# every other edge is a loop closure.
odometry_edges = [(e[0], e[1]) for e in data["edges"] if e[1] == e[0] + 1]
loop_closure_edges = [(e[0], e[1]) for e in data["edges"] if e[1] != e[0] + 1]

print(f"Poses: {num_poses}")
print(
    f"Edges: {num_edges} ({len(odometry_edges)} odometry, {len(loop_closure_edges)} loop closures)"
)
Poses: 3500
Edges: 5453 (3499 odometry, 1954 loop closures)

Variables and costs

Use SE2Var for poses on SE(2). The between factor measures the relative pose between two nodes, weighted by a precision matrix.

The sqrt_precision matrix is the upper Cholesky factor of the information (precision) matrix \(\Lambda = \Sigma^{-1}\), where \(\Sigma\) is the measurement covariance: writing \(\Lambda = L L^T\) with \(L\) lower triangular, sqrt_precision is \(L^T\). Multiplying the residual by this factor gives a whitened residual whose squared norm is the squared Mahalanobis distance: \(r^T \Lambda \, r = \|L^T r\|^2\).

# A single batched SE2Var covering every pose, with ids 0..num_poses-1
# (the same indices the edges use to refer to poses).
pose_vars = jaxls.SE2Var(id=jnp.arange(0, num_poses))


@jaxls.Cost.factory
def between_cost(
    vals: jaxls.VarValues,
    var_a: jaxls.SE2Var,
    var_b: jaxls.SE2Var,
    measured: jaxlie.SE2,
    sqrt_precision: jax.Array,
) -> jax.Array:
    """Whitened residual of a relative-pose measurement between two poses.

    The predicted relative transform T_a^{-1} @ T_b is compared with the
    measurement in the tangent space: log(measured^{-1} @ (T_a^{-1} @ T_b)).
    The result is then weighted by ``sqrt_precision``. Assuming that matrix is
    R with R^T R equal to the information matrix (as produced by
    ``parse_precision_matrix``), the residual's squared norm is the squared
    Mahalanobis distance.

    Args:
        vals: Current variable values.
        var_a: Pose the measurement is expressed relative to.
        var_b: Pose being measured.
        measured: Measured relative transform from a to b.
        sqrt_precision: Square-root information matrix for this measurement.

    Returns:
        The whitened tangent-space residual.
    """
    pose_a = vals[var_a]
    pose_b = vals[var_b]
    predicted_rel = pose_a.inverse() @ pose_b
    tangent_error = (measured.inverse() @ predicted_rel).log()
    return sqrt_precision @ tangent_error


@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE2Var,
    target: jaxlie.SE2,
) -> jax.Array:
    """Equality constraint that pins one pose to a fixed target.

    Relative measurements alone leave the global position and heading free
    (gauge freedom). Requiring log(pose^{-1} @ target) == 0 removes that
    freedom.

    Args:
        vals: Current variable values.
        var: The pose to pin.
        target: Pose the variable must equal.

    Returns:
        Tangent-space difference between the current pose and the target.
    """
    current_pose = vals[var]
    offset = current_pose.inverse() @ target
    return offset.log()
# Stack the per-edge fields into arrays so all between factors can be created
# in one batched call. Each edge is (i, j, dx, dy, dtheta, precision_components).
edge_i = jnp.asarray([i for i, *_ in data["edges"]])
edge_j = jnp.asarray([j for _, j, *_ in data["edges"]])

# One measured relative SE(2) transform per edge.
measured_poses = jaxlie.SE2.from_xy_theta(
    jnp.asarray([dx for _, _, dx, *_ in data["edges"]]),
    jnp.asarray([dy for _, _, _, dy, *_ in data["edges"]]),
    jnp.asarray([dtheta for *_, dtheta, _ in data["edges"]]),
)

# Square-root information matrices: one (3, 3) factor per edge.
precision_comps = jnp.asarray([comps for *_, comps in data["edges"]])
sqrt_precisions = jax.vmap(parse_precision_matrix)(precision_comps)

print(f"Batched edge arrays: {edge_i.shape[0]} edges")
Batched edge arrays: 5453 edges

Solving

Problem structure

The problem structure connects pose variables through between factors (odometry and loop closures). With 3500 poses and 5453 edges, the visualization automatically limits nodes for performance:

# Cost terms: every edge becomes a between factor (created in one batched call),
# plus an equality constraint anchoring pose 0 at its initial value.
costs: list[jaxls.Cost] = []
costs.append(
    between_cost(
        jaxls.SE2Var(id=edge_i),
        jaxls.SE2Var(id=edge_j),
        measured_poses,
        sqrt_precisions,
    )
)
costs.append(
    anchor_cost(
        jaxls.SE2Var(id=0),
        jaxlie.SE2.from_xy_theta(*data["poses"][0]),
    )
)

# Visualize the problem structure.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars])
problem.show()
# Initial guess: the vertex poses stored in the g2o file (for this dataset,
# the odometry-integrated trajectory).
initial_poses = jaxlie.SE2.from_xy_theta(*jnp.asarray(data["poses"]).T)
initial_vals = jaxls.VarValues.make([pose_vars.with_value(initial_poses)])

# Analyze the problem (this groups and vectorizes the costs, as logged below).
problem = problem.analyze()
INFO     | Building optimization problem with 5454 terms and 3500 variables: 5453 costs, 1 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 5453 costs, 2 variables each: between_cost
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_cost
# Solve with plain Gauss-Newton: trust_region=None disables step damping. The
# cost is therefore not guaranteed to decrease monotonically (the log below shows
# several accepted steps that raise the cost), but the iterations still converge
# on this dataset.
solution = problem.solve(initial_vals, trust_region=None)
INFO     | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.0000e+07, constraint_dim=3
INFO     |  step #0: cost=2634707.7500 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 2634707.75000 (avg 161.05556)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.32e+04 cost_prev=2634707.7500 cost_new=57105.4102
INFO     |  step #1: cost=57105.3906 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 57105.39062 (avg 3.49076)
INFO     |      - augmented_anchor_cost(1): 0.01625 (avg 0.00542)
INFO     |      accepted=True ATb_norm=3.59e+03 cost_prev=57105.4102 cost_new=11499.3945
INFO     |  step #2: cost=11499.3926 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 11499.39258 (avg 0.70294)
INFO     |      - augmented_anchor_cost(1): 0.00040 (avg 0.00013)
INFO     |      accepted=True ATb_norm=1.24e+03 cost_prev=11499.3945 cost_new=3325.4399
INFO     |  step #3: cost=3325.4395 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 3325.43945 (avg 0.20328)
INFO     |      - augmented_anchor_cost(1): 0.00016 (avg 0.00005)
INFO     |      accepted=True ATb_norm=4.12e+02 cost_prev=3325.4399 cost_new=11917.8184
INFO     |  step #4: cost=11917.8184 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 11917.81836 (avg 0.72852)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.06e+03 cost_prev=11917.8184 cost_new=2345.9548
INFO     |  step #5: cost=2345.9546 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 2345.95459 (avg 0.14340)
INFO     |      - augmented_anchor_cost(1): 0.00001 (avg 0.00000)
INFO     |      accepted=True ATb_norm=4.09e+02 cost_prev=2345.9548 cost_new=705.6475
INFO     |  step #6: cost=705.6476 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 705.64758 (avg 0.04314)
INFO     |      - augmented_anchor_cost(1): 0.00002 (avg 0.00001)
INFO     |      accepted=True ATb_norm=1.58e+02 cost_prev=705.6475 cost_new=1310.0546
INFO     |  step #7: cost=1310.0547 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 1310.05469 (avg 0.08008)
INFO     |      - augmented_anchor_cost(1): 0.00001 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.80e+02 cost_prev=1310.0546 cost_new=181.4150
INFO     |  step #8: cost=181.4150 lambd=0.0000 inexact_tol=1.0e-02
INFO     |      - between_cost(5453): 181.41498 (avg 0.01109)
INFO     |      - augmented_anchor_cost(1): 0.00001 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.18e+01 cost_prev=181.4150 cost_new=381.8472
INFO     |  step #9: cost=381.8471 lambd=0.0000 inexact_tol=5.5e-03
INFO     |      - between_cost(5453): 381.84714 (avg 0.02334)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.28e+02 cost_prev=381.8472 cost_new=139.9707
INFO     |  step #10: cost=139.9707 lambd=0.0000 inexact_tol=5.5e-03
INFO     |      - between_cost(5453): 139.97066 (avg 0.00856)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.30e+00 cost_prev=139.9707 cost_new=139.6545
INFO     |  step #11: cost=139.6545 lambd=0.0000 inexact_tol=6.0e-04
INFO     |      - between_cost(5453): 139.65454 (avg 0.00854)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=9.20e+00 cost_prev=139.6545 cost_new=137.9442
INFO     |  step #12: cost=137.9442 lambd=0.0000 inexact_tol=6.0e-04
INFO     |      - between_cost(5453): 137.94420 (avg 0.00843)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=7.06e-01 cost_prev=137.9442 cost_new=137.9200
INFO     |  step #13: cost=137.9200 lambd=0.0000 inexact_tol=6.0e-04
INFO     |      - between_cost(5453): 137.91998 (avg 0.00843)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.22e-01 cost_prev=137.9200 cost_new=137.9158
INFO     |  step #14: cost=137.9158 lambd=0.0000 inexact_tol=6.0e-04
INFO     |      - between_cost(5453): 137.91579 (avg 0.00843)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.64e-01 cost_prev=137.9158 cost_new=137.9130
INFO     |  step #15: cost=137.9130 lambd=0.0000 inexact_tol=6.0e-04
INFO     |      - between_cost(5453): 137.91296 (avg 0.00843)
INFO     |      - augmented_anchor_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.52e-01 cost_prev=137.9130 cost_new=137.9128
INFO     |  AL update: snorm=4.8128e-09, csupn=4.8128e-09, max_rho=1.0000e+07
INFO     | Terminated @ iteration #16: cost=137.9128 criteria=[1 0 0], term_deltas=8.9e-07,1.3e-01,2.2e-05

Visualization

Compare the initial odometry-only trajectory with the optimized result. Loop closures correct drift accumulated from odometry:

# xy positions of every pose before and after optimization, as NumPy arrays for plotting.
initial_xy, optimized_xy = (
    np.array(values[pose_vars].translation()) for values in (initial_vals, solution)
)

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots


def get_trajectory_trace(
    positions: np.ndarray,
    name: str,
    color: str,
) -> go.Scatter:
    """Build a line trace that draws a 2D trajectory in pose order.

    Args:
        positions: (N, 2) array of xy positions.
        name: Legend label for the trace.
        color: Line color.

    Returns:
        A Plotly ``Scatter`` trace in ``"lines"`` mode.
    """
    xs, ys = positions[:, 0], positions[:, 1]
    line_style = {"color": color, "width": 1.5}
    return go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        line=line_style,
        name=name,
        hovertemplate="(%{x:.1f}, %{y:.1f})<extra></extra>",
    )


def get_loop_closure_traces(
    positions: np.ndarray,
    edges: list[tuple[int, int]],
    color: str,
    max_edges: int = 200,
) -> list[go.Scatter]:
    """Create a single trace of loop-closure segments, subsampled for performance.

    Every ``step``-th edge is kept. ``step`` is chosen so that at most
    ``max_edges`` segments are drawn. All segments share one trace, and
    consecutive segments are separated by ``None`` so they are not joined.

    Args:
        positions: (N, 2) array of xy positions, indexed by pose id.
        edges: (i, j) pose-index pairs for the loop-closure edges.
        color: Line color.
        max_edges: Upper bound on the number of segments drawn; must be positive.

    Returns:
        List containing a single Plotly Scatter trace for all sampled edges.

    Raises:
        ValueError: If ``max_edges`` is not positive.
    """
    if max_edges <= 0:
        raise ValueError(f"max_edges must be positive, got {max_edges}")

    # Ceiling division so the bound actually holds. Floor division can keep
    # almost twice max_edges segments (e.g. 1954 edges -> step 9 -> 218 > 200).
    step = max(1, -(-len(edges) // max_edges))
    sampled = edges[::step]

    x_coords = []
    y_coords = []
    for i, j in sampled:
        # Segment endpoints followed by a None gap that breaks the polyline.
        x_coords.extend([positions[i, 0], positions[j, 0], None])
        y_coords.extend([positions[i, 1], positions[j, 1], None])

    return [
        go.Scatter(
            x=x_coords,
            y=y_coords,
            mode="lines",
            line=dict(color=color, width=0.5),
            opacity=0.4,
            name="Loop closures",
            hoverinfo="skip",
        )
    ]

Hide code cell source

from IPython.display import HTML

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Initial (odometry only)", "Optimized (with loop closures)"),
)

# One panel per trajectory: left = initial guess, right = optimized result.
# In each panel the loop-closure segments are added first, so the trajectory
# is drawn on top of them.
for col, xy, path_color in (
    (1, initial_xy, "steelblue"),
    (2, optimized_xy, "forestgreen"),
):
    for trace in get_loop_closure_traces(xy, loop_closure_edges, "tomato"):
        fig.add_trace(trace, row=1, col=col)
    fig.add_trace(get_trajectory_trace(xy, "Trajectory", path_color), row=1, col=col)

# Shared bounds over both trajectories, padded by 5% of the larger extent.
all_xy = np.concatenate([initial_xy, optimized_xy])
x_min, y_min = all_xy.min(axis=0)
x_max, y_max = all_xy.max(axis=0)
padding = 0.05 * max(x_max - x_min, y_max - y_min)

# Equal aspect ratio (x scaled to y) and identical ranges in both panels.
fig.update_xaxes(
    title_text="x (m)",
    scaleanchor="y",
    scaleratio=1,
    range=[x_min - padding, x_max + padding],
)
fig.update_yaxes(
    title_text="y (m)",
    range=[y_min - padding, y_max + padding],
)
fig.update_layout(
    height=500,
    showlegend=False,
    margin={"t": 40, "b": 40, "l": 40, "r": 40},
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

The optimization corrects drift accumulated from odometry-only integration. Loop closures (a subsample is shown in red) constrain revisited locations to be consistent, resulting in a globally coherent map.

For more on Lie group variables, see jaxls.SE2Var and jaxls.SE3Var.