Bundle adjustment
In this notebook, we solve a bundle adjustment problem: jointly refining camera poses and 3D point positions to minimize reprojection error.
Inputs: Initial camera poses, 3D points, and 2D observations from BAL dataset
Outputs: Refined camera poses and 3D point cloud
Features used:
- `SE3Var` for camera poses (batched)
- Custom `Point3Var` for 3D landmark positions
- Batched reprojection costs with Huber loss
- Trust region solver for large-scale optimization
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
Load BAL dataset
Download and parse a problem from the Dubrovnik dataset. The BAL format stores:
Camera parameters: Rodrigues rotation (3), translation (3), focal length (1), distortion k1, k2 (2)
3D point positions
2D observations with camera and point indices
# Fetch one Dubrovnik BAL problem (16 cameras, 22,106 points) and parse it into
# camera parameters, 3D points, 2D observations, and per-observation indices.
bal_url = "https://grail.cs.washington.edu/projects/bal/data/dubrovnik/problem-16-22106-pre.txt.bz2"
bal_path = download_bal_dataset(bal_url)
(
    camera_params,
    points_3d,
    observations,
    camera_indices,
    point_indices,
) = parse_bal_file(bal_path)

# The leading dimension of each array gives the problem sizes.
n_cameras = camera_params.shape[0]
n_points = points_3d.shape[0]
n_obs = observations.shape[0]

print(f"camera_params: {camera_params.shape}")
print(f"points_3d: {points_3d.shape}")
print(f"observations: {observations.shape}")
print(f"camera_indices: {camera_indices.shape}")
print(f"point_indices: {point_indices.shape}")
camera_params: (16, 9)
points_3d: (22106, 3)
observations: (83718, 2)
camera_indices: (83718,)
point_indices: (83718,)
Camera model
The BAL camera model uses:
World-to-camera: \(P = R \cdot X + t\)
Perspective projection: \(p = -P_{xy} / P_z\)
Radial distortion: \(r(p) = 1 + k_1 ||p||^2 + k_2 ||p||^4\)
Final projection: \(p' = f \cdot r(p) \cdot p\)
@jax.jit
def project_point(
    point_world: jax.Array,
    T_camera_world: jaxlie.SE3,
    focal: float,
    k1: float,
    k2: float,
) -> jax.Array:
    """Map a world-frame 3D point to 2D image coordinates with the BAL camera model.

    Steps: rigid transform into the camera frame, perspective division using
    the BAL sign convention ``-P_xy / P_z``, polynomial radial distortion
    ``1 + k1*r^2 + k2*r^4``, and finally scaling by the focal length.

    Args:
        point_world: 3D point expressed in the world frame, shape (3,).
        T_camera_world: World-to-camera rigid transform.
        focal: Focal length.
        k1: Coefficient of the r^2 radial distortion term.
        k2: Coefficient of the r^4 radial distortion term.

    Returns:
        Projected 2D point, shape (2,).
    """
    # World frame -> camera frame.
    in_camera = T_camera_world @ point_world
    xy, depth = in_camera[:2], in_camera[2]

    # BAL sign convention: negate x/y before dividing by depth.
    normalized = -xy / depth

    # Radial distortion factor evaluated at the normalized image point.
    radius_sq = jnp.sum(normalized**2)
    radial_scale = 1.0 + k1 * radius_sq + k2 * radius_sq**2

    return focal * radial_scale * normalized
BAL to SE3 conversion
BAL stores camera extrinsics as a Rodrigues (axis-angle) rotation vector plus a translation. An axis-angle vector is exactly the tangent-space (Lie algebra) coordinate of a rotation, so jaxlie.SO3.exp() converts the Rodrigues vector directly to SO3:
@jax.jit
def bal_params_to_se3(params: jax.Array) -> jaxlie.SE3:
    """Build an SE3 pose from the extrinsic part of a BAL camera vector.

    The first three entries are a Rodrigues (axis-angle) vector, which is
    mapped to a rotation with the SO3 exponential map. The next three entries
    are used as the translation.

    Args:
        params: Camera parameters laid out as [rodrigues(3), translation(3)].

    Returns:
        The camera pose as a ``jaxlie.SE3``.
    """
    rodrigues = params[:3]
    translation = params[3:6]
    rotation = jaxlie.SO3.exp(rodrigues)
    return jaxlie.SE3.from_rotation_and_translation(
        rotation=rotation,
        translation=translation,
    )
# Split the BAL camera parameters: columns 0-5 feed the SE3 pose
# (Rodrigues rotation + translation); columns 6, 7, 8 are the focal length
# and the two radial distortion coefficients k1, k2.
initial_poses = jax.vmap(bal_params_to_se3)(camera_params[:, :6])
focal_lengths = camera_params[:, 6]
distortion_k1 = camera_params[:, 7]
distortion_k2 = camera_params[:, 8]

# Add Gaussian noise to the 3D points to make the optimization more
# interesting. The standard deviation is defined once, so the summary printed
# below always matches the noise that was actually applied.
point_noise_std = 0.5
key = jax.random.PRNGKey(0)
point_noise = jax.random.normal(key, points_3d.shape) * point_noise_std
points_3d_noisy = points_3d + point_noise

print(f"Initial poses shape: {initial_poses.wxyz_xyz.shape}")
print(
    f"Focal lengths range: [{float(focal_lengths.min()):.1f}, {float(focal_lengths.max()):.1f}]"
)
print(f"Added noise to 3D points: std={point_noise_std} units")
Initial poses shape: (16, 7)
Focal lengths range: [809.5, 1918.2]
Added noise to 3D points: std=0.5 units
Variables and costs
class Point3Var(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
    """3D landmark position.

    A jaxls variable whose value is a ``jax.Array``. The ``default_factory``
    returns a length-3 zero vector; presumably jaxls uses it as the default
    value when none is supplied (confirm against the jaxls ``Var`` docs).
    """
# One batched SE3 variable per camera and one Point3Var per landmark. Ids run
# 0..n-1, the same index space as camera_indices / point_indices, which are
# used when the costs are built.
camera_ids = jnp.arange(n_cameras)
point_ids = jnp.arange(n_points)
camera_vars = jaxls.SE3Var(id=camera_ids)
point_vars = Point3Var(id=point_ids)

print(f"Camera variables: {n_cameras}")
print(f"Point variables: {n_points}")
Camera variables: 16
Point variables: 22106
@jaxls.Cost.factory
def reprojection_cost(
    vals: jaxls.VarValues,
    camera_var: jaxls.SE3Var,
    point_var: Point3Var,
    observed_px: jax.Array,
    focal: float,
    k1: float,
    k2: float,
) -> jax.Array:
    """Huber-weighted pixel residual between a projected landmark and its observation.

    Args:
        vals: Current variable values.
        camera_var: Pose variable of the observing camera (world-to-camera).
        point_var: Landmark variable being observed.
        observed_px: Observed 2D image coordinates.
        focal: Focal length of the observing camera.
        k1: First radial distortion coefficient.
        k2: Second radial distortion coefficient.

    Returns:
        Per-axis residual ``projected - observed`` multiplied by the square
        root of an IRLS Huber weight.
    """
    predicted_px = project_point(vals[point_var], vals[camera_var], focal, k1, k2)
    error = predicted_px - observed_px

    # IRLS form of the Huber loss, applied independently to each axis:
    #   |e| <= delta -> weight 1            (quadratic region)
    #   |e| >  delta -> weight delta / |e|  (linear region)
    # Multiplying the residual by sqrt(weight) turns the squared cost into
    # delta * |e| in the linear region. stop_gradient keeps autodiff from
    # differentiating through the weight, which would be unstable.
    huber_delta = 2.0  # pixels
    magnitude = jnp.abs(error) + 1e-8  # epsilon avoids division by zero
    in_linear_region = magnitude > huber_delta
    irls_weight = jax.lax.stop_gradient(
        jnp.where(in_linear_region, huber_delta / magnitude, 1.0)
    )
    return error * jnp.sqrt(irls_weight)
Problem construction
Create a single batched reprojection cost that covers all observations:
# Gather each observation's intrinsics from the camera that observed it, so
# every cost argument has one entry per observation.
obs_focal = focal_lengths[camera_indices]
obs_k1 = distortion_k1[camera_indices]
obs_k2 = distortion_k2[camera_indices]

# A single vectorized cost stands in for every observation.
batched_reprojection = reprojection_cost(
    jaxls.SE3Var(id=camera_indices),
    Point3Var(id=point_indices),
    observations,
    obs_focal,
    obs_k1,
    obs_k2,
)
costs = [batched_reprojection]
print(f"Created 1 batched cost representing {n_obs} observations")
Created 1 batched cost representing 83718 observations
# Starting point for the solver: BAL camera poses and noise-perturbed points.
camera_init = camera_vars.with_value(initial_poses)
point_init = point_vars.with_value(points_3d_noisy)
initial_vals = jaxls.VarValues.make([camera_init, point_init])

# Least-squares problem over every camera and point variable.
problem = jaxls.LeastSquaresProblem(costs, [camera_vars, point_vars])

# Visualize the problem structure.
problem.show()

# Analyze the problem; the returned object replaces `problem` and is the one
# that gets solved.
problem = problem.analyze()
INFO | Building optimization problem with 83718 terms and 22122 variables: 83718 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 83718 costs, 2 variables each: reprojection_cost
Solving
# Solve starting from the noisy initial values. linear_solver="cholmod" selects
# the linear-solver backend (presumably SuiteSparse CHOLMOD sparse Cholesky —
# confirm in the jaxls docs).
solution = problem.solve(initial_vals, linear_solver="cholmod")
INFO | step #0: cost=6650404.5000 lambd=0.0005
INFO | - reprojection_cost(83718): 6650404.50000 (avg 39.71908)
INFO | accepted=True ATb_norm=1.16e+07 cost_prev=6650404.5000 cost_new=550783.3750
INFO | step #1: cost=550783.3750 lambd=0.0003
INFO | - reprojection_cost(83718): 550783.37500 (avg 3.28952)
INFO | accepted=True ATb_norm=3.57e+06 cost_prev=550783.3750 cost_new=409055.8438
INFO | step #2: cost=409055.8438 lambd=0.0001
INFO | - reprojection_cost(83718): 409055.84375 (avg 2.44306)
INFO | accepted=True ATb_norm=1.68e+06 cost_prev=409055.8438 cost_new=395470.8750
INFO | step #3: cost=395470.8750 lambd=0.0001
INFO | - reprojection_cost(83718): 395470.87500 (avg 2.36192)
INFO | step #4: cost=395470.8750 lambd=0.0001
INFO | - reprojection_cost(83718): 395470.87500 (avg 2.36192)
INFO | step #5: cost=395470.8750 lambd=0.0003
INFO | - reprojection_cost(83718): 395470.87500 (avg 2.36192)
INFO | accepted=True ATb_norm=7.23e+05 cost_prev=395470.8750 cost_new=391111.0000
INFO | step #6: cost=391111.0000 lambd=0.0001
INFO | - reprojection_cost(83718): 391111.00000 (avg 2.33588)
INFO | accepted=True ATb_norm=2.89e+05 cost_prev=391111.0000 cost_new=389297.6562
INFO | step #7: cost=389297.6562 lambd=0.0001
INFO | - reprojection_cost(83718): 389297.65625 (avg 2.32505)
INFO | step #8: cost=389297.6562 lambd=0.0001
INFO | - reprojection_cost(83718): 389297.65625 (avg 2.32505)
INFO | accepted=True ATb_norm=1.92e+05 cost_prev=389297.6562 cost_new=388351.4062
INFO | step #9: cost=388351.4062 lambd=0.0001
INFO | - reprojection_cost(83718): 388351.40625 (avg 2.31940)
INFO | step #10: cost=388351.4062 lambd=0.0001
INFO | - reprojection_cost(83718): 388351.40625 (avg 2.31940)
INFO | step #11: cost=388351.4062 lambd=0.0003
INFO | - reprojection_cost(83718): 388351.40625 (avg 2.31940)
INFO | accepted=True ATb_norm=1.05e+05 cost_prev=388351.4062 cost_new=387798.1562
INFO | step #12: cost=387798.1562 lambd=0.0001
INFO | - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO | step #13: cost=387798.1562 lambd=0.0003
INFO | - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO | step #14: cost=387798.1562 lambd=0.0005
INFO | - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO | step #15: cost=387798.1562 lambd=0.0010
INFO | - reprojection_cost(83718): 387798.15625 (avg 2.31610)
INFO | accepted=True ATb_norm=1.04e+05 cost_prev=387798.1562 cost_new=387432.3125
INFO | step #16: cost=387432.3125 lambd=0.0005
INFO | - reprojection_cost(83718): 387432.31250 (avg 2.31391)
INFO | step #17: cost=387432.3125 lambd=0.0010
INFO | - reprojection_cost(83718): 387432.31250 (avg 2.31391)
INFO | accepted=True ATb_norm=5.00e+04 cost_prev=387432.3125 cost_new=387194.6250
INFO | step #18: cost=387194.6250 lambd=0.0005
INFO | - reprojection_cost(83718): 387194.62500 (avg 2.31249)
INFO | accepted=True ATb_norm=3.83e+04 cost_prev=387194.6250 cost_new=387021.4375
INFO | step #19: cost=387021.4375 lambd=0.0003
INFO | - reprojection_cost(83718): 387021.43750 (avg 2.31146)
INFO | step #20: cost=387021.4375 lambd=0.0005
INFO | - reprojection_cost(83718): 387021.43750 (avg 2.31146)
INFO | step #21: cost=387021.4375 lambd=0.0010
INFO | - reprojection_cost(83718): 387021.43750 (avg 2.31146)
INFO | accepted=True ATb_norm=3.88e+04 cost_prev=387021.4375 cost_new=386884.6250
INFO | step #22: cost=386884.6250 lambd=0.0005
INFO | - reprojection_cost(83718): 386884.62500 (avg 2.31064)
INFO | step #23: cost=386884.6250 lambd=0.0010
INFO | - reprojection_cost(83718): 386884.62500 (avg 2.31064)
INFO | step #24: cost=386884.6250 lambd=0.0020
INFO | - reprojection_cost(83718): 386884.62500 (avg 2.31064)
INFO | accepted=True ATb_norm=1.94e+04 cost_prev=386884.6250 cost_new=386786.8750
INFO | step #25: cost=386786.8750 lambd=0.0010
INFO | - reprojection_cost(83718): 386786.87500 (avg 2.31006)
INFO | step #26: cost=386786.8750 lambd=0.0020
INFO | - reprojection_cost(83718): 386786.87500 (avg 2.31006)
INFO | accepted=True ATb_norm=1.65e+04 cost_prev=386786.8750 cost_new=386715.5625
INFO | step #27: cost=386715.5625 lambd=0.0010
INFO | - reprojection_cost(83718): 386715.56250 (avg 2.30963)
INFO | step #28: cost=386715.5625 lambd=0.0020
INFO | - reprojection_cost(83718): 386715.56250 (avg 2.30963)
INFO | step #29: cost=386715.5625 lambd=0.0040
INFO | - reprojection_cost(83718): 386715.56250 (avg 2.30963)
INFO | accepted=True ATb_norm=1.57e+04 cost_prev=386715.5625 cost_new=386657.0000
INFO | Terminated @ iteration #30: cost=386657.0000 criteria=[0 0 1], term_deltas=1.5e-04,6.5e+01,9.1e-07
Reprojection error analysis
@jax.jit
def compute_reprojection_errors(
    vals: jaxls.VarValues,
    cam_idx: jax.Array,
    pt_idx: jax.Array,
    obs: jax.Array,
    *,
    focal: "jax.Array | None" = None,
    k1: "jax.Array | None" = None,
    k2: "jax.Array | None" = None,
) -> jax.Array:
    """Compute reprojection errors for all observations.

    Args:
        vals: Variable values containing camera poses and 3D points.
        cam_idx: Camera indices for each observation (n_obs,).
        pt_idx: Point indices for each observation (n_obs,).
        obs: 2D observations (n_obs, 2).
        focal: Optional per-camera focal lengths, indexed by ``cam_idx``.
            Defaults to the module-level ``focal_lengths``.
        k1: Optional per-camera first radial distortion coefficients.
            Defaults to the module-level ``distortion_k1``.
        k2: Optional per-camera second radial distortion coefficients.
            Defaults to the module-level ``distortion_k2``.

    Returns:
        Euclidean pixel distance between projected and observed points for
        each observation (n_obs,).
    """
    # jax.jit freezes module globals into the compiled trace, so reassigning
    # them later would be silently ignored. Accepting the intrinsics as
    # arguments avoids that and makes the helper reusable. None falls back to
    # the original globals, so existing callers are unaffected.
    if focal is None:
        focal = focal_lengths
    if k1 is None:
        k1 = distortion_k1
    if k2 is None:
        k2 = distortion_k2

    # Gather each observation's intrinsics once, outside the vmap, so the
    # per-observation function only receives scalars.
    obs_focal = jnp.asarray(focal)[cam_idx]
    obs_k1 = jnp.asarray(k1)[cam_idx]
    obs_k2 = jnp.asarray(k2)[cam_idx]

    def single_error(
        c_idx: jax.Array,
        p_idx: jax.Array,
        observed: jax.Array,
        f: jax.Array,
        d1: jax.Array,
        d2: jax.Array,
    ) -> jax.Array:
        pose = vals[jaxls.SE3Var(id=c_idx)]
        point = vals[Point3Var(id=p_idx)]
        projected = project_point(point, pose, f, d1, d2)
        return jnp.linalg.norm(projected - observed)

    return jax.vmap(single_error)(cam_idx, pt_idx, obs, obs_focal, obs_k1, obs_k2)
# Evaluate the same set of observations before and after optimization.
obs_args = (camera_indices, point_indices, observations)
errors_initial = compute_reprojection_errors(initial_vals, *obs_args)
errors_final = compute_reprojection_errors(solution, *obs_args)

print(
    f"Initial reprojection error: mean={float(errors_initial.mean()):.2f}px, median={float(jnp.median(errors_initial)):.2f}px"
)
print(
    f"Final reprojection error: mean={float(errors_final.mean()):.2f}px, median={float(jnp.median(errors_final)):.2f}px"
)
Initial reprojection error: mean=31.24px, median=26.71px
Final reprojection error: mean=2.28px, median=1.65px
Visualization
╭────── viser (listening *:8082) ───────╮ │ ╷ │ │ HTTP │ http://localhost:8082 │ │ Websocket │ ws://localhost:8082 │ │ ╵ │ ╰───────────────────────────────────────╯
Showing 5297 of 22106 points (error < 1.0px)
Bundle adjustment jointly refines camera poses and 3D point positions to minimize reprojection error. The Huber loss provides robustness to outliers, which is important for real-world data.
For solver configuration, see jaxls.TrustRegionConfig. For Lie group variables, see jaxls.SE3Var.