Bundle adjustment#

In this notebook, we solve a bundle adjustment problem: jointly refining camera poses and 3D point positions to minimize reprojection error.

Inputs: Initial camera poses, 3D points, and 2D observations from a BAL dataset
Outputs: Refined camera poses and a 3D point cloud

Features used:

  • SE3Var for camera poses (batched)

  • Custom Point3Var for 3D landmark positions

  • Batched reprojection costs with Huber loss

  • Trust region solver for large-scale optimization


import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxlie
import jaxls

Load BAL dataset#

Download and parse a problem from the Dubrovnik dataset. The BAL format (file layout sketched after this list) stores:

  • Camera parameters: Rodrigues rotation (3), translation (3), focal length (1), distortion k1, k2 (2)

  • 3D point positions

  • 2D observations with camera and point indices
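
The parser below assumes the standard BAL text layout: a single header line, followed by one line per observation, then nine lines per camera, then three lines per point:

<n_cameras> <n_points> <n_observations>
<camera_index> <point_index> <x_pixel> <y_pixel>   # one line per observation
...
<camera parameter>   # 9 lines per camera: rotation (3), translation (3), f, k1, k2
...
<point coordinate>   # 3 lines per point
...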


import bz2
import urllib.request
from pathlib import Path


def download_bal_dataset(url: str, cache_dir: Path = Path("/tmp/bal_data")) -> Path:
    """Download and cache a BAL dataset.

    Args:
        url: URL to download the dataset from
        cache_dir: Directory to cache downloaded files

    Returns:
        Path to the downloaded (cached) file
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    filename = url.split("/")[-1]
    local_path = cache_dir / filename

    if not local_path.exists():
        logger.info(f"Downloading {filename}...")
        urllib.request.urlretrieve(url, local_path)
    return local_path


def parse_bal_file(
    path: Path,
) -> tuple[
    jax.Array,  # camera_params: (n_cameras, 9)
    jax.Array,  # points_3d: (n_points, 3)
    jax.Array,  # observations: (n_obs, 2) pixel coords
    jax.Array,  # camera_indices: (n_obs,)
    jax.Array,  # point_indices: (n_obs,)
]:
    """Parse a BAL dataset file.

    Args:
        path: Path to the BAL dataset file (bz2 compressed)

    Returns:
        Tuple of (camera_params, points_3d, observations, camera_indices, point_indices)
    """
    with bz2.open(path, "rt") as f:
        n_cameras, n_points, n_obs = map(int, f.readline().split())

        # Read observations.
        camera_indices = []
        point_indices = []
        observations = []
        for _ in range(n_obs):
            parts = f.readline().split()
            camera_indices.append(int(parts[0]))
            point_indices.append(int(parts[1]))
            observations.append([float(parts[2]), float(parts[3])])

        # Read camera parameters (9 values each)
        camera_params = []
        for _ in range(n_cameras):
            params = [float(f.readline()) for _ in range(9)]
            camera_params.append(params)

        # Read 3D points (3 values each)
        points_3d = []
        for _ in range(n_points):
            point = [float(f.readline()) for _ in range(3)]
            points_3d.append(point)

    return (
        jnp.array(camera_params),
        jnp.array(points_3d),
        jnp.array(observations),
        jnp.array(camera_indices),
        jnp.array(point_indices),
    )
# Download Dubrovnik dataset (16 cameras, 22,106 points)
bal_url = "https://grail.cs.washington.edu/projects/bal/data/dubrovnik/problem-16-22106-pre.txt.bz2"
bal_path = download_bal_dataset(bal_url)

camera_params, points_3d, observations, camera_indices, point_indices = parse_bal_file(
    bal_path
)
n_cameras, n_points, n_obs = (
    camera_params.shape[0],
    points_3d.shape[0],
    observations.shape[0],
)

print(f"camera_params:  {camera_params.shape}")
print(f"points_3d:      {points_3d.shape}")
print(f"observations:   {observations.shape}")
print(f"camera_indices: {camera_indices.shape}")
print(f"point_indices:  {point_indices.shape}")
camera_params:  (16, 9)
points_3d:      (22106, 3)
observations:   (83718, 2)
camera_indices: (83718,)
point_indices:  (83718,)

Camera model#

The BAL camera model uses:

  • World-to-camera: \(P = R \cdot X + t\)

  • Perspective projection: \(p = -P_{xy} / P_z\)

  • Radial distortion: \(r(p) = 1 + k_1 ||p||^2 + k_2 ||p||^4\)

  • Final projection: \(p' = f \cdot r(p) \cdot p\)

@jax.jit
def project_point(
    point_world: jax.Array,
    T_camera_world: jaxlie.SE3,
    focal: float,
    k1: float,
    k2: float,
) -> jax.Array:
    """Project 3D point to 2D using BAL camera model.

    Args:
        point_world: 3D point in world frame (3,)
        T_camera_world: Camera pose (world to camera transform)
        focal: Focal length
        k1: First radial distortion coefficient
        k2: Second radial distortion coefficient

    Returns:
        2D projected point (2,)
    """
    # Transform point to camera frame.
    point_cam = T_camera_world @ point_world

    # Perspective projection (BAL convention: -P_xy / P_z)
    p = -point_cam[:2] / point_cam[2]

    # Radial distortion.
    r_sq = jnp.sum(p**2)
    distortion = 1.0 + k1 * r_sq + k2 * r_sq**2

    return focal * distortion * p
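
As a quick sanity check (hypothetical values, not from the dataset), projecting a point with an identity pose, unit focal length, and zero distortion should give \(-P_{xy} / P_z\); a point at \((0.1, -0.2, -2.0)\) in front of the camera lands at \((0.05, -0.1)\):

# Hypothetical example values; not part of the BAL pipeline.
test_point = jnp.array([0.1, -0.2, -2.0])
test_pose = jaxlie.SE3.identity()
print(project_point(test_point, test_pose, focal=1.0, k1=0.0, k2=0.0))
# Expected: approximately [0.05, -0.1]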

BAL to SE3 conversion#

BAL stores camera extrinsics as a Rodrigues (axis-angle) rotation plus a translation. We use jaxlie.SO3.exp() to convert the Rodrigues vector directly to an SO3 rotation:

@jax.jit
def bal_params_to_se3(params: jax.Array) -> jaxlie.SE3:
    """Convert BAL camera parameters to SE3 pose.

    BAL stores rotation as Rodrigues vector (axis-angle).

    Args:
        params: Camera parameters [rodrigues(3), translation(3)]

    Returns:
        SE3 camera pose
    """
    return jaxlie.SE3.from_rotation_and_translation(
        rotation=jaxlie.SO3.exp(params[:3]),
        translation=params[3:6],
    )


# Convert all cameras.
initial_poses = jax.vmap(bal_params_to_se3)(camera_params[:, :6])
focal_lengths = camera_params[:, 6]
distortion_k1 = camera_params[:, 7]
distortion_k2 = camera_params[:, 8]

# Add noise to 3D points to make optimization more interesting.
key = jax.random.PRNGKey(0)
point_noise = jax.random.normal(key, points_3d.shape) * 0.5
points_3d_noisy = points_3d + point_noise

print(f"Initial poses shape: {initial_poses.wxyz_xyz.shape}")
print(
    f"Focal lengths range: [{float(focal_lengths.min()):.1f}, {float(focal_lengths.max()):.1f}]"
)
print("Added noise to 3D points: std=0.5 units")
Initial poses shape: (16, 7)
Focal lengths range: [809.5, 1918.2]
Added noise to 3D points: std=0.5 units

Variables and costs#

class Point3Var(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
    """3D landmark position."""


# Create batched variables.
camera_vars = jaxls.SE3Var(id=jnp.arange(n_cameras))
point_vars = Point3Var(id=jnp.arange(n_points))

print(f"Camera variables: {n_cameras}")
print(f"Point variables: {n_points}")
Camera variables: 16
Point variables: 22106
@jaxls.Cost.factory
def reprojection_cost(
    vals: jaxls.VarValues,
    camera_var: jaxls.SE3Var,
    point_var: Point3Var,
    observed_px: jax.Array,
    focal: float,
    k1: float,
    k2: float,
) -> jax.Array:
    """Reprojection error with Huber loss for robustness."""
    pose = vals[camera_var]
    point = vals[point_var]
    projected = project_point(point, pose, focal, k1, k2)
    residual = projected - observed_px

    # IRLS-style Huber weighting for robustness to outliers.
    # For |r| <= delta: weight = 1 (quadratic region)
    # For |r| > delta: weight = delta / |r| (linear region)
    # stop_gradient prevents instabilities from differentiating through weights.
    delta = 2.0  # pixels
    abs_r = jnp.abs(residual) + 1e-8
    weight = jax.lax.stop_gradient(jnp.where(abs_r > delta, delta / abs_r, 1.0))
    return residual * jnp.sqrt(weight)
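
The comments above describe the standard IRLS approximation of the Huber loss: squaring the weighted residual gives a cost that is quadratic near zero but grows only linearly beyond \(\delta\),

\[
w(r) = \begin{cases} 1 & |r| \le \delta \\ \delta / |r| & |r| > \delta \end{cases}
\qquad\Rightarrow\qquad
w(r)\,r^2 = \begin{cases} r^2 & |r| \le \delta \\ \delta\,|r| & |r| > \delta \end{cases}
\]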

Problem construction#

Create batched reprojection costs for all observations:

# Create a single batched cost for all observations.
costs = [
    reprojection_cost(
        jaxls.SE3Var(id=camera_indices),
        Point3Var(id=point_indices),
        observations,
        focal_lengths[camera_indices],
        distortion_k1[camera_indices],
        distortion_k2[camera_indices],
    )
]

print(f"Created 1 batched cost representing {n_obs} observations")
Created 1 batched cost representing 83718 observations
# Initial values (with added noise)
initial_vals = jaxls.VarValues.make(
    [
        camera_vars.with_value(initial_poses),
        point_vars.with_value(points_3d_noisy),
    ]
)

# Create the problem.
problem = jaxls.LeastSquaresProblem(costs, [camera_vars, point_vars])

# Visualize the problem structure.
problem.show()
# Analyze the problem.
problem = problem.analyze()
INFO     | Building optimization problem with 83718 terms and 22122 variables: 83718 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 83718 costs, 2 variables each: reprojection_cost

Solving#

solution = problem.solve(initial_vals, linear_solver="cholmod")
INFO     |  step #0: cost=6650404.5000 lambd=0.0005
INFO     |      - reprojection_cost(83718): 6650404.50000 (avg 39.71908)
INFO     |      accepted=True ATb_norm=1.16e+07 cost_prev=6650404.5000 cost_new=550783.3750
INFO     |  step #1: cost=550783.3750 lambd=0.0003
INFO     |      - reprojection_cost(83718): 550783.37500 (avg 3.28952)
INFO     |      accepted=True ATb_norm=3.57e+06 cost_prev=550783.3750 cost_new=409055.8438
INFO     |  step #2: cost=409055.8438 lambd=0.0001
INFO     |      - reprojection_cost(83718): 409055.84375 (avg 2.44306)
INFO     |      accepted=True ATb_norm=1.68e+06 cost_prev=409055.8438 cost_new=395470.8750
INFO     |  step #3: cost=395470.8750 lambd=0.0001
INFO     |      - reprojection_cost(83718): 395470.87500 (avg 2.36192)
INFO     |  step #4: cost=395470.8750 lambd=0.0001
INFO     |      - reprojection_cost(83718): 395470.87500 (avg 2.36192)
INFO     |  step #5: cost=395470.8750 lambd=0.0003
INFO     |      - reprojection_cost(83718): 395470.87500 (avg 2.36192)
INFO     |      accepted=True ATb_norm=7.23e+05 cost_prev=395470.8750 cost_new=391111.0000
INFO     |  step #6: cost=391111.0000 lambd=0.0001
INFO     |      - reprojection_cost(83718): 391111.00000 (avg 2.33588)
INFO     |      accepted=True ATb_norm=2.89e+05 cost_prev=391111.0000 cost_new=389297.6562
INFO     |  step #7: cost=389297.6562 lambd=0.0001
INFO     |      - reprojection_cost(83718): 389297.65625 (avg 2.32505)
INFO     |  step #8: cost=389297.6562 lambd=0.0001
INFO     |      - reprojection_cost(83718): 389297.65625 (avg 2.32505)
INFO     |      accepted=True ATb_norm=1.92e+05 cost_prev=389297.6562 cost_new=388351.4062
INFO     |  step #9: cost=388351.4062 lambd=0.0001
INFO     |      - reprojection_cost(83718): 388351.40625 (avg 2.31940)
INFO     |  step #10: cost=388351.4062 lambd=0.0001
INFO     |      - reprojection_cost(83718): 388351.40625 (avg 2.31940)
INFO     |  step #11: cost=388351.4062 lambd=0.0003
INFO     |      - reprojection_cost(83718): 388351.40625 (avg 2.31940)
INFO     |      accepted=True ATb_norm=1.05e+05 cost_prev=388351.4062 cost_new=387798.1562
INFO     |  step #12: cost=387798.1562 lambd=0.0001
INFO     |      - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO     |  step #13: cost=387798.1562 lambd=0.0003
INFO     |      - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO     |  step #14: cost=387798.1562 lambd=0.0005
INFO     |      - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO     |  step #15: cost=387798.1562 lambd=0.0010
INFO     |      - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO     |      accepted=True ATb_norm=1.04e+05 cost_prev=387798.1562 cost_new=387432.3125
INFO     |  step #16: cost=387432.3125 lambd=0.0005
INFO     |      - reprojection_cost(83718): 387432.31250 (avg 2.31391)
INFO     |  step #17: cost=387432.3125 lambd=0.0010
INFO     |      - reprojection_cost(83718): 387432.31250 (avg 2.31391)
INFO     |      accepted=True ATb_norm=5.00e+04 cost_prev=387432.3125 cost_new=387194.6250
INFO     |  step #18: cost=387194.6250 lambd=0.0005
INFO     |      - reprojection_cost(83718): 387194.62500 (avg 2.31249)
INFO     |      accepted=True ATb_norm=3.83e+04 cost_prev=387194.6250 cost_new=387021.4375
INFO     |  step #19: cost=387021.4375 lambd=0.0003
INFO     |      - reprojection_cost(83718): 387021.43750 (avg 2.31146)
INFO     |  step #20: cost=387021.4375 lambd=0.0005
INFO     |      - reprojection_cost(83718): 387021.43750 (avg 2.31146)
INFO     |  step #21: cost=387021.4375 lambd=0.0010
INFO     |      - reprojection_cost(83718): 387021.43750 (avg 2.31146)
INFO     |      accepted=True ATb_norm=3.88e+04 cost_prev=387021.4375 cost_new=386884.6250
INFO     |  step #22: cost=386884.6250 lambd=0.0005
INFO     |      - reprojection_cost(83718): 386884.62500 (avg 2.31064)
INFO     |  step #23: cost=386884.6250 lambd=0.0010
INFO     |      - reprojection_cost(83718): 386884.62500 (avg 2.31064)
INFO     |  step #24: cost=386884.6250 lambd=0.0020
INFO     |      - reprojection_cost(83718): 386884.62500 (avg 2.31064)
INFO     |      accepted=True ATb_norm=1.94e+04 cost_prev=386884.6250 cost_new=386786.8750
INFO     |  step #25: cost=386786.8750 lambd=0.0010
INFO     |      - reprojection_cost(83718): 386786.87500 (avg 2.31006)
INFO     |  step #26: cost=386786.8750 lambd=0.0020
INFO     |      - reprojection_cost(83718): 386786.87500 (avg 2.31006)
INFO     |      accepted=True ATb_norm=1.65e+04 cost_prev=386786.8750 cost_new=386715.5625
INFO     |  step #27: cost=386715.5625 lambd=0.0010
INFO     |      - reprojection_cost(83718): 386715.56250 (avg 2.30963)
INFO     |  step #28: cost=386715.5625 lambd=0.0020
INFO     |      - reprojection_cost(83718): 386715.56250 (avg 2.30963)
INFO     |  step #29: cost=386715.5625 lambd=0.0040
INFO     |      - reprojection_cost(83718): 386715.56250 (avg 2.30963)
INFO     |      accepted=True ATb_norm=1.57e+04 cost_prev=386715.5625 cost_new=386657.0000
INFO     | Terminated @ iteration #30: cost=386657.0000 criteria=[0 0 1], term_deltas=1.5e-04,6.5e+01,9.1e-07

Reprojection error analysis#

@jax.jit
def compute_reprojection_errors(
    vals: jaxls.VarValues,
    cam_idx: jax.Array,
    pt_idx: jax.Array,
    obs: jax.Array,
) -> jax.Array:
    """Compute reprojection errors for all observations.

    Args:
        vals: Variable values containing camera poses and 3D points
        cam_idx: Camera indices for each observation (n_obs,)
        pt_idx: Point indices for each observation (n_obs,)
        obs: 2D observations (n_obs, 2)

    Returns:
        Reprojection errors for each observation (n_obs,)
    """

    def single_error(
        c_idx: jax.Array, p_idx: jax.Array, observed: jax.Array
    ) -> jax.Array:
        pose = vals[jaxls.SE3Var(id=c_idx)]
        point = vals[Point3Var(id=p_idx)]
        projected = project_point(
            point,
            pose,
            focal_lengths[c_idx],
            distortion_k1[c_idx],
            distortion_k2[c_idx],
        )
        return jnp.linalg.norm(projected - observed)

    return jax.vmap(single_error)(cam_idx, pt_idx, obs)


errors_initial = compute_reprojection_errors(
    initial_vals, camera_indices, point_indices, observations
)
errors_final = compute_reprojection_errors(
    solution, camera_indices, point_indices, observations
)

print(
    f"Initial reprojection error: mean={float(errors_initial.mean()):.2f}px, median={float(jnp.median(errors_initial)):.2f}px"
)
print(
    f"Final reprojection error:   mean={float(errors_final.mean()):.2f}px, median={float(jnp.median(errors_final)):.2f}px"
)
Initial reprojection error: mean=31.24px, median=26.71px
Final reprojection error:   mean=2.28px, median=1.65px
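
As a small, hypothetical follow-up (not part of the original output above), the same error arrays can be thresholded to estimate an inlier fraction; the 2 px threshold matches the Huber delta used in the cost:

# Hypothetical follow-up: fraction of observations within the 2 px Huber delta.
inlier_fraction = float((errors_final < 2.0).mean())
print(f"Observations with error < 2px: {inlier_fraction:.1%}")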

Visualization#


import contextlib
import io
import numpy as np
import viser

# Compute per-point mean reprojection error for filtering.
point_indices_np = np.array(point_indices)
point_errors = np.bincount(
    point_indices_np, weights=np.array(errors_final), minlength=n_points
)
point_counts = np.bincount(point_indices_np, minlength=n_points)
point_mean_errors = point_errors / np.maximum(point_counts, 1)

# Filter to points with low reprojection error.
error_threshold = 1.0  # pixels
good_point_mask = point_mean_errors < error_threshold
good_point_indices = np.where(good_point_mask)[0]

points_init = np.array(initial_vals[Point3Var])
points_opt = np.array(solution[Point3Var])

# Subsample filtered points for visualization.
subsample = max(1, len(good_point_indices) // 30_000)
vis_indices = good_point_indices[::subsample]

print(f"Showing {len(vis_indices)} of {n_points} points (error < {error_threshold}px)")

# Create Viser server (suppress output).
with (
    contextlib.redirect_stdout(io.StringIO()),
    contextlib.redirect_stderr(io.StringIO()),
):
    server = viser.ViserServer(verbose=False)
server.scene.set_up_direction("+z")

# Offset for side-by-side views; the midpoint between the two clouds sits at the origin.
offset = 20.0

# Center the point clouds at origin.
points_center = points_opt[vis_indices].mean(axis=0)
label_height = points_opt[vis_indices][:, 2].max() - points_center[2] + 5.0

# Set initial camera position for a good view of both point clouds.
server.initial_camera.position = (0.0, -80.0, 30.0)
server.initial_camera.look_at = (0.0, 0.0, 0.0)

# Add labels (optimized on left, initial on right).
server.scene.add_label(
    "/optimized_label",
    text="Optimized",
    position=(-offset, 0.0, label_height),
    anchor="bottom-center",
)
server.scene.add_label(
    "/initial_label",
    text="Initial (with noise)",
    position=(offset, 0.0, label_height),
    anchor="bottom-center",
)

# Add optimized point cloud (left side).
server.scene.add_point_cloud(
    "/optimized/points",
    points=points_opt[vis_indices] - points_center + np.array([-offset, 0, 0]),
    colors=np.full((len(vis_indices), 3), [70, 130, 180], dtype=np.uint8),  # Steel blue
    point_size=0.08,
)

# Add initial point cloud (right side).
server.scene.add_point_cloud(
    "/initial/points",
    points=points_init[vis_indices] - points_center + np.array([offset, 0, 0]),
    colors=np.full((len(vis_indices), 3), 150, dtype=np.uint8),  # Gray
    point_size=0.08,
)

# Add camera frustums.
# BAL stores T_camera_world (world-to-camera transform).
# BAL cameras look down -Z, but Viser expects +Z, so we flip by 180° around X.
poses_init = initial_vals[jaxls.SE3Var]
poses_opt = solution[jaxls.SE3Var]

# 180° rotation around X-axis to flip from -Z to +Z forward direction.
flip_rotation = jaxlie.SO3.from_x_radians(np.pi)

for i in range(n_cameras):
    # Optimized camera (left side).
    pose_opt = jaxlie.SE3(wxyz_xyz=poses_opt.wxyz_xyz[i])
    T_world_camera_opt = pose_opt.inverse()
    cam_pos_opt = (
        np.array(T_world_camera_opt.translation())
        - points_center
        + np.array([-offset, 0, 0])
    )
    cam_rot_opt = T_world_camera_opt.rotation() @ flip_rotation

    server.scene.add_camera_frustum(
        f"/optimized/camera_{i}",
        fov=np.pi / 3,
        aspect=1.5,
        scale=0.5,
        wxyz=np.array(cam_rot_opt.wxyz),
        position=cam_pos_opt,
        color=(34, 139, 34),  # Forest green
    )

    # Initial camera (right side).
    pose_init = jaxlie.SE3(wxyz_xyz=poses_init.wxyz_xyz[i])
    T_world_camera_init = pose_init.inverse()
    cam_pos_init = (
        np.array(T_world_camera_init.translation())
        - points_center
        + np.array([offset, 0, 0])
    )
    cam_rot_init = T_world_camera_init.rotation() @ flip_rotation

    server.scene.add_camera_frustum(
        f"/initial/camera_{i}",
        fov=np.pi / 3,
        aspect=1.5,
        scale=0.5,
        wxyz=np.array(cam_rot_init.wxyz),
        position=cam_pos_init,
        color=(255, 99, 71),  # Tomato
    )

# Display inline in the notebook.
server.scene.show(height=500)
╭────── viser (listening *:8082) ───────╮
│             ╷                         │
│   HTTP      │ http://localhost:8082   │
│   Websocket │ ws://localhost:8082     │
│             ╵                         │
╰───────────────────────────────────────╯
Showing 5297 of 22106 points (error < 1.0px)


import plotly.graph_objects as go
from IPython.display import HTML

# Reprojection error histogram.
fig_hist = go.Figure()

# Clip errors for visualization (otherwise a few large outliers compress the histogram).
max_error = 50.0
errors_initial_clipped = jnp.clip(errors_initial, 0, max_error)
errors_final_clipped = jnp.clip(errors_final, 0, max_error)

fig_hist.add_trace(
    go.Histogram(
        x=errors_initial_clipped,
        name="Initial",
        opacity=0.7,
        nbinsx=50,
        marker_color="lightcoral",
    )
)

fig_hist.add_trace(
    go.Histogram(
        x=errors_final_clipped,
        name="Optimized",
        opacity=0.7,
        nbinsx=50,
        marker_color="steelblue",
    )
)

fig_hist.update_layout(
    title="Reprojection Error Distribution",
    xaxis=dict(title="Reprojection Error (pixels)", range=[0, max_error]),
    yaxis_title="Count",
    barmode="overlay",
    height=350,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)

HTML(fig_hist.to_html(full_html=False, include_plotlyjs="cdn"))

Bundle adjustment jointly refines camera poses and 3D point positions to minimize reprojection error. The Huber loss provides robustness to outliers, which is important for real-world data.

For solver configuration, see jaxls.TrustRegionConfig. For Lie group variables, see jaxls.SE3Var.