Camera calibration

In this notebook, we solve a camera calibration problem: estimating intrinsic and extrinsic parameters from checkerboard observations.

Inputs: Images of a checkerboard pattern (OpenCV sample data)
Outputs: Camera intrinsics (focal length, principal point, distortion coefficients) and one extrinsic pose per image

Features used:

  • Var subclass for camera intrinsics

  • SE3Var for camera extrinsic poses

  • @jaxls.Cost.factory for reprojection error

  • OpenCV for chessboard corner detection

Hide code cell source

import sys
from loguru import logger

# Send loguru output to stdout with a compact "LEVEL | message" format so the
# solver's progress logs show up inline with the notebook output.
_LOG_FORMAT = "<level>{level: <8}</level> | {message}"
# remove() with no argument drops every handler registered so far.
logger.remove()
logger.add(sys.stdout, format=_LOG_FORMAT);
import urllib.request
from pathlib import Path

import cv2
import jax
import jax.numpy as jnp
import jaxls
import jaxlie
import numpy as np
from scipy import ndimage

Download OpenCV sample images

Download the calibration images from the OpenCV repository:

def download_calibration_images(
    cache_dir: Path = Path("/tmp/opencv_calib"),
) -> list[Path]:
    """Download OpenCV sample calibration images, reusing cached copies.

    Images already present in ``cache_dir`` are not downloaded again. Each
    download is first written to a ``<name>.part`` file and renamed into place
    only after it completes. An interrupted download therefore never leaves a
    truncated image that a later run would mistake for a valid cached copy.

    Args:
        cache_dir: Directory to cache downloaded images (created if missing).

    Returns:
        List of paths to the local image files, in ``image_indices`` order.

    Raises:
        urllib.error.URLError: Propagated from ``urllib.request.urlretrieve``
            if a missing image cannot be downloaded.
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    base_url = "https://raw.githubusercontent.com/opencv/opencv/master/samples/data"

    # Note: left10.jpg doesn't exist in the OpenCV repo.
    image_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14]

    image_paths = []
    for i in image_indices:
        filename = f"left{i:02d}.jpg"
        local_path = cache_dir / filename

        if not local_path.exists():
            url = f"{base_url}/{filename}"
            logger.info(f"Downloading {filename}...")
            partial_path = local_path.with_name(filename + ".part")
            try:
                urllib.request.urlretrieve(url, partial_path)
                # Same-directory rename: the final name only ever refers to a
                # fully written file.
                partial_path.replace(local_path)
            finally:
                # No-op after a successful rename; removes leftovers on failure.
                partial_path.unlink(missing_ok=True)

        image_paths.append(local_path)

    return image_paths


# Fetch the sample images, reusing any copies already in the local cache.
image_paths: list[Path] = download_calibration_images()
print(f"Downloaded {len(image_paths)} calibration images")
Downloaded 13 calibration images

Detect chessboard corners

Use OpenCV to detect the inner corners of the 9x6 chessboard pattern:

# Chessboard geometry: 9 x 6 inner corners on a square grid.
board_cols, board_rows = 9, 6
square_size = 0.025  # 25mm squares (approximate)

# 3D inner-corner coordinates in the board frame. The board lies in the Z=0
# plane, and points are ordered with the column index varying fastest:
# index = row * board_cols + col -> (col, row, 0) * square_size.
grid_col, grid_row = np.meshgrid(np.arange(board_cols), np.arange(board_rows))
corner_xy = np.stack([grid_col.ravel(), grid_row.ravel()], axis=-1) * square_size

board_points_3d = np.zeros((board_rows * board_cols, 3), np.float32)
board_points_3d[:, :2] = corner_xy
board_points_3d = jnp.array(board_points_3d)

print(f"Chessboard: {board_cols}x{board_rows} = {len(board_points_3d)} corners")
print(
    f"Board size: {board_cols * square_size * 1000:.0f}mm x {board_rows * square_size * 1000:.0f}mm"
)
Chessboard: 9x6 = 54 corners
Board size: 225mm x 150mm
# Detect corners in all images.
observations_2d: list[jax.Array] = []
valid_image_indices: list[int] = []
image_size = None

# cornerSubPix stops after 30 iterations or once the corner update is < 0.001.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001)

for i, path in enumerate(image_paths):
    img = cv2.imread(str(path))
    if img is None:
        # cv2.imread returns None (rather than raising) for unreadable files.
        print(f"  Image {i + 1:2d}: ✗ Could not read {path}")
        continue

    img_wh = (img.shape[1], img.shape[0])  # (width, height)
    if image_size is None:
        image_size = img_wh
    elif img_wh != image_size:
        # A single set of intrinsics only makes sense for one resolution.
        raise ValueError(
            f"{path} is {img_wh[0]}x{img_wh[1]}, expected {image_size[0]}x{image_size[1]}"
        )

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, corners = cv2.findChessboardCorners(gray, (board_cols, board_rows), None)

    if ret:
        # Refine corner positions to sub-pixel accuracy.
        corners_refined = cv2.cornerSubPix(gray, corners, (11, 11), (-1, -1), criteria)
        # squeeze() drops OpenCV's singleton axis, leaving (N, 2) pixel coords.
        observations_2d.append(jnp.array(corners_refined.squeeze()))
        valid_image_indices.append(i)
        print(f"  Image {i + 1:2d}: ✓ Found {len(corners)} corners")
    else:
        print(f"  Image {i + 1:2d}: ✗ Chessboard not found")

num_views = len(observations_2d)
if num_views == 0:
    # Fail here with a clear message instead of deep inside problem construction.
    raise RuntimeError("Chessboard not found in any calibration image")
print(f"\nSuccessfully detected corners in {num_views}/{len(image_paths)} images")
print(f"Image size: {image_size[0]}x{image_size[1]}")
  Image  1: ✓ Found 54 corners
  Image  2: ✓ Found 54 corners
  Image  3: ✓ Found 54 corners
  Image  4: ✓ Found 54 corners
  Image  5: ✓ Found 54 corners
  Image  6: ✓ Found 54 corners
  Image  7: ✓ Found 54 corners
  Image  8: ✓ Found 54 corners
  Image  9: ✓ Found 54 corners
  Image 10: ✓ Found 54 corners
  Image 11: ✓ Found 54 corners
  Image 12: ✓ Found 54 corners
  Image 13: ✓ Found 54 corners

Successfully detected corners in 13/13 images
Image size: 640x480

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

# Show sample input images with detected corners.
# Keep only sample indices that exist, so this cell still works when corners
# were detected in fewer views than usual.
sample_indices = [i for i in (0, 3, 6) if i < num_views]  # Up to 3 sample images

fig_samples = make_subplots(
    rows=1,
    cols=len(sample_indices),
    subplot_titles=[f"Image {valid_image_indices[i] + 1}" for i in sample_indices],
)

for col, idx in enumerate(sample_indices):
    # idx indexes detected views; map back to the original image file.
    img = cv2.imread(str(image_paths[valid_image_indices[idx]]))
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Draw detected corners (drawChessboardCorners draws into img_rgb in place).
    corners = np.array(observations_2d[idx])
    cv2.drawChessboardCorners(
        img_rgb, (board_cols, board_rows), corners.reshape(-1, 1, 2), True
    )

    fig_samples.add_trace(
        go.Image(z=img_rgb),
        row=1,
        col=col + 1,
    )

fig_samples.update_xaxes(showticklabels=False)
fig_samples.update_yaxes(showticklabels=False)
fig_samples.update_layout(
    height=280,
    margin=dict(t=40, b=20, l=20, r=20),
)
HTML(fig_samples.to_html(full_html=False, include_plotlyjs="cdn"))

Camera model

We use the Brown-Conrady distortion model (same as OpenCV):

\[x' = x_n (1 + k_1 r^2 + k_2 r^4) + 2 p_1 x_n y_n + p_2 (r^2 + 2 x_n^2)\]
\[y' = y_n (1 + k_1 r^2 + k_2 r^4) + p_1 (r^2 + 2 y_n^2) + 2 p_2 x_n y_n\]
\[u = f_x \cdot x' + c_x, \quad v = f_y \cdot y' + c_y\]

where \((x_n, y_n)\) are normalized coordinates, \(r^2 = x_n^2 + y_n^2\), \((k_1, k_2)\) are radial distortion coefficients, and \((p_1, p_2)\) are tangential distortion coefficients.

def _default_intrinsics() -> jax.Array:
    """Placeholder value: fx=fy=500, (cx, cy)=(320, 240), zero distortion."""
    return jnp.array([500.0, 500.0, 320.0, 240.0, 0.0, 0.0, 0.0, 0.0])


class IntrinsicsVar(jaxls.Var[jax.Array], default_factory=_default_intrinsics):
    """Optimization variable packing the camera intrinsics.

    Layout: [fx, fy, cx, cy, k1, k2, p1, p2] (focal lengths and principal
    point in pixels, then Brown-Conrady radial and tangential coefficients).
    """


@jax.jit
def project_brown_conrady(
    points_camera: jax.Array,  # (N, 3) points in camera frame
    intrinsics: jax.Array,  # [fx, fy, cx, cy, k1, k2, p1, p2]
) -> jax.Array:
    """Map camera-frame 3D points to pixel coordinates with the Brown-Conrady model.

    Pipeline: perspective divide, then radial (k1, k2) and tangential (p1, p2)
    distortion of the normalized coordinates, then the pinhole intrinsics
    (fx, fy, cx, cy).

    Args:
        points_camera: 3D points in camera frame (N, 3). Depth is clamped to at
            least 1e-6 before dividing.
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Projected (u, v) pixel coordinates, shape (N, 2).
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = intrinsics

    # Perspective divide. The clamp keeps points on or behind the camera plane
    # from dividing by zero; their projections are not meaningful anyway.
    depth = jnp.maximum(points_camera[..., 2], 1e-6)
    xn = points_camera[..., 0] / depth
    yn = points_camera[..., 1] / depth

    # Radial gain (k1, k2) plus tangential / decentering offsets (p1, p2).
    rsq = xn**2 + yn**2
    gain = 1.0 + k1 * rsq + k2 * rsq**2
    xd = xn * gain + 2 * p1 * xn * yn + p2 * (rsq + 2 * xn**2)
    yd = yn * gain + p1 * (rsq + 2 * yn**2) + 2 * p2 * xn * yn

    # Pinhole camera matrix, applied per axis.
    return jnp.stack([fx * xd + cx, fy * yd + cy], axis=-1)

Problem construction

We jointly optimize the camera intrinsics and all extrinsic poses by minimizing the reprojection error:

# Decision variables: one intrinsics vector shared by every view, plus one
# SE3 pose per view in which the chessboard was detected.
intrinsics_var = IntrinsicsVar(id=0)
pose_vars = [jaxls.SE3Var(id=view) for view in range(num_views)]
@jaxls.Cost.factory
def reprojection_cost(
    vals: jaxls.VarValues,
    intrinsics_var: IntrinsicsVar,
    pose_var: jaxls.SE3Var,
    points_3d: jax.Array,  # (N, 3) 3D points
    observed_2d: jax.Array,  # (N, 2) observed 2D points
) -> jax.Array:
    """Reprojection residual for one view.

    Moves the board points into the camera frame with this view's pose,
    projects them through the shared intrinsics, and subtracts the detected
    corner positions.

    Args:
        vals: Current values of the optimization variables.
        intrinsics_var: Shared camera intrinsics variable.
        pose_var: This view's pose; applied to world points to express them in
            the camera frame (world-to-camera).
        points_3d: Board points in the world frame (N, 3).
        observed_2d: Detected corner pixels (N, 2).

    Returns:
        Flattened residual of length 2N, ordered (du_0, dv_0, du_1, dv_1, ...).
    """
    intr = vals[intrinsics_var]
    world_to_cam = vals[pose_var]

    cam_points = jax.vmap(world_to_cam.apply)(points_3d)
    residual_uv = project_brown_conrady(cam_points, intr) - observed_2d
    return residual_uv.flatten()
# One cost per view, each covering all N board corners of that view. The
# per-view costs share a structure, so jaxls can batch them together.
costs: list[jaxls.Cost] = [
    reprojection_cost(intrinsics_var, view_pose_var, board_points_3d, view_obs)
    for view_pose_var, view_obs in zip(pose_vars, observations_2d)
]

print(f"Created {len(costs)} batched reprojection costs ({num_views} views)")
Created 13 batched reprojection costs (13 views)
# Initialize intrinsics with rough guesses: focal length ~ half the image
# width (fy also uses the width, i.e. square pixels are assumed), principal
# point at the image center, and zero distortion.
img_w, img_h = image_size  # (width, height)
init_fx = float(img_w) / 2
init_fy = float(img_w) / 2
init_cx = float(img_w) / 2
init_cy = float(img_h) / 2
init_intrinsics = jnp.array([init_fx, init_fy, init_cx, init_cy, 0.0, 0.0, 0.0, 0.0])

print(
    f"Initial intrinsics: fx={init_fx:.0f}, fy={init_fy:.0f}, cx={init_cx:.0f}, cy={init_cy:.0f}"
)
Initial intrinsics: fx=320, fy=320, cx=320, cy=240
def estimate_initial_pose(
    observed_corners: jax.Array, intrinsics: jax.Array
) -> jaxlie.SE3:
    """Estimate a view's initial pose from its detected corners with cv2.solvePnP.

    The board model is the module-level ``board_points_3d``. The returned pose
    maps board (world) points into the camera frame, which is how pose
    variables are applied in ``reprojection_cost``.

    Args:
        observed_corners: Detected 2D corner positions (N, 2), in the same
            order as ``board_points_3d``.
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Estimated world-to-camera pose as SE3.

    Raises:
        RuntimeError: If solvePnP reports failure.
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = intrinsics
    camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)
    # OpenCV's distortion vector starts (k1, k2, p1, p2, ...), matching our layout.
    dist_coeffs = np.array([k1, k2, p1, p2], dtype=np.float64)

    success, rvec, tvec = cv2.solvePnP(
        np.array(board_points_3d),
        np.array(observed_corners),
        camera_matrix,
        dist_coeffs,
    )
    if not success:
        # Don't silently seed the optimizer with an undefined pose.
        raise RuntimeError("cv2.solvePnP failed to estimate an initial pose")

    # rvec is an axis-angle vector; Rodrigues converts it to a 3x3 matrix.
    R, _ = cv2.Rodrigues(rvec)
    rotation = jaxlie.SO3.from_matrix(jnp.array(R))
    translation = jnp.array(tvec.squeeze())

    return jaxlie.SE3.from_rotation_and_translation(rotation, translation)


# Estimate initial poses.
init_poses = [estimate_initial_pose(obs, init_intrinsics) for obs in observations_2d]
print(f"Estimated {len(init_poses)} initial poses using PnP")
Estimated 13 initial poses using PnP
# Create initial values: the rough intrinsics guess plus one PnP pose per view.
initial_vals = jaxls.VarValues.make(
    [intrinsics_var.with_value(init_intrinsics)]
    + [pose_var.with_value(pose) for pose_var, pose in zip(pose_vars, init_poses)]
)

# Create the problem over all costs and variables.
problem = jaxls.LeastSquaresProblem(costs, [intrinsics_var] + pose_vars)

# Visualize the problem structure.
problem.show()
# Analyze the problem.
problem = problem.analyze()
INFO     | Building optimization problem with 13 terms and 14 variables: 13 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 13 costs, 2 variables each: reprojection_cost

Solving

# Jointly refine the intrinsics and every view's pose, starting from initial_vals.
solution = problem.solve(initial_vals)
INFO     |  step #0: cost=11776.8213 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |      accepted=True ATb_norm=3.08e+04 cost_prev=11776.8213 cost_new=118103.0781
INFO     |  step #1: cost=118103.0781 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 118103.07812 (avg 84.11900)
INFO     |      accepted=True ATb_norm=4.23e+06 cost_prev=118103.0781 cost_new=630.2039
INFO     |  step #2: cost=630.2039 lambd=0.0001 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #3: cost=630.2039 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #4: cost=630.2039 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #5: cost=630.2039 lambd=0.0010 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #6: cost=630.2039 lambd=0.0020 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #7: cost=630.2039 lambd=0.0040 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #8: cost=630.2039 lambd=0.0080 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |  step #9: cost=630.2039 lambd=0.0160 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 630.20386 (avg 0.44886)
INFO     |      accepted=True ATb_norm=9.59e+04 cost_prev=630.2039 cost_new=236.5836
INFO     |  step #10: cost=236.5836 lambd=0.0080 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 236.58362 (avg 0.16851)
INFO     |      accepted=True ATb_norm=1.02e+05 cost_prev=236.5836 cost_new=117.5282
INFO     |  step #11: cost=117.5282 lambd=0.0040 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #12: cost=117.5282 lambd=0.0080 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #13: cost=117.5282 lambd=0.0160 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #14: cost=117.5282 lambd=0.0320 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #15: cost=117.5282 lambd=0.0640 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |  step #16: cost=117.5282 lambd=0.1280 inexact_tol=4.6e-04
INFO     |      - reprojection_cost(13): 117.52823 (avg 0.08371)
INFO     |      accepted=True ATb_norm=1.61e+03 cost_prev=117.5282 cost_new=117.4151
INFO     |  step #17: cost=117.4151 lambd=0.0640 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |  step #18: cost=117.4151 lambd=0.1280 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |  step #19: cost=117.4151 lambd=0.2560 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |  step #20: cost=117.4151 lambd=0.5120 inexact_tol=2.2e-04
INFO     |      - reprojection_cost(13): 117.41513 (avg 0.08363)
INFO     |      accepted=True ATb_norm=1.46e+01 cost_prev=117.4151 cost_new=117.4097
INFO     |  step #21: cost=117.4097 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40974 (avg 0.08363)
INFO     |  step #22: cost=117.4097 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40974 (avg 0.08363)
INFO     |      accepted=True ATb_norm=7.08e+00 cost_prev=117.4097 cost_new=117.4080
INFO     |  step #23: cost=117.4080 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40798 (avg 0.08362)
INFO     |  step #24: cost=117.4080 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40798 (avg 0.08362)
INFO     |      accepted=True ATb_norm=8.25e+00 cost_prev=117.4080 cost_new=117.4064
INFO     |  step #25: cost=117.4064 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40639 (avg 0.08362)
INFO     |  step #26: cost=117.4064 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40639 (avg 0.08362)
INFO     |      accepted=True ATb_norm=5.53e+00 cost_prev=117.4064 cost_new=117.4042
INFO     |  step #27: cost=117.4042 lambd=0.2560 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40421 (avg 0.08362)
INFO     |  step #28: cost=117.4042 lambd=0.5120 inexact_tol=7.5e-05
INFO     |      - reprojection_cost(13): 117.40421 (avg 0.08362)
INFO     |      accepted=True ATb_norm=5.43e+00 cost_prev=117.4042 cost_new=117.4032
INFO     | Terminated @ iteration #29: cost=117.4032 criteria=[1 0 0], term_deltas=8.9e-06,2.9e+00,3.7e-05
# Compare results: print initial guess vs. optimized value for each parameter.
est_intrinsics = solution[intrinsics_var]

print("Estimated intrinsics:")
param_names = ["fx", "fy", "cx", "cy", "k1", "k2", "p1", "p2"]
print(f"  {'Parameter':<12} {'Initial':>12} {'Estimated':>12}")
print(f"  {'-' * 38}")
for name, init, est in zip(param_names, init_intrinsics, est_intrinsics):
    print(f"  {name:<12} {float(init):>12.4f} {float(est):>12.4f}")
Estimated intrinsics:
  Parameter         Initial    Estimated
  --------------------------------------
  fx               320.0000     536.4322
  fy               320.0000     536.3876
  cx               320.0000     342.2786
  cy               240.0000     235.6965
  k1                 0.0000      -0.2786
  k2                 0.0000       0.0673
  p1                 0.0000       0.0018
  p2                 0.0000      -0.0003

Visualization

Hide code cell source

def compute_reprojection_errors(
    intrinsics: jax.Array, poses: list[jaxlie.SE3]
) -> tuple[list[jax.Array], list[jax.Array]]:
    """Project the board into every view and measure per-corner pixel error.

    Uses the module-level ``board_points_3d`` and compares view ``k`` against
    ``observations_2d[k]``.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        poses: World-to-camera pose for each view.

    Returns:
        ``(projected, errors)``: per view, the (N, 2) projected corners and the
        (N,) Euclidean distances to the observed corners.
    """
    projected_per_view: list[jax.Array] = []
    errors_per_view: list[jax.Array] = []
    for view, world_to_cam in enumerate(poses):
        uv = project_brown_conrady(
            jax.vmap(world_to_cam.apply)(board_points_3d), intrinsics
        )
        projected_per_view.append(uv)
        errors_per_view.append(jnp.linalg.norm(uv - observations_2d[view], axis=-1))
    return projected_per_view, errors_per_view


# Per-corner errors for the PnP initialization and for the optimized solution.
init_projected, init_errors = compute_reprojection_errors(init_intrinsics, init_poses)
est_poses = [solution[pose_var] for pose_var in pose_vars]
est_projected, est_errors = compute_reprojection_errors(est_intrinsics, est_poses)

# RMSE pooled over every corner of every view.
init_rmse, est_rmse = (
    float(jnp.sqrt(jnp.mean(jnp.concatenate([e**2 for e in per_view]))))
    for per_view in (init_errors, est_errors)
)
print(
    f"Reprojection RMSE: {init_rmse:.3f} px (initial) -> {est_rmse:.3f} px (optimized)"
)
Reprojection RMSE: 4.096 px (initial) -> 0.409 px (optimized)

Hide code cell source

# Reprojection error distribution: overlaid histograms before/after optimization.
all_init_errors = jnp.concatenate(init_errors)
all_est_errors = jnp.concatenate(est_errors)

fig_errors = go.Figure()
for hist_values, hist_name, hist_color in (
    (all_init_errors, "Initial", "tomato"),
    (all_est_errors, "Optimized", "steelblue"),
):
    fig_errors.add_trace(
        go.Histogram(
            x=hist_values,
            name=hist_name,
            marker_color=hist_color,
            opacity=0.7,
            nbinsx=30,
        )
    )
fig_errors.update_layout(
    title="Reprojection Error Distribution",
    xaxis_title="Error (pixels)",
    yaxis_title="Count",
    barmode="overlay",
    height=300,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)
HTML(fig_errors.to_html(full_html=False, include_plotlyjs="cdn"))

Hide code cell source

# Camera poses (top-down view)
# Each pose maps board points into the camera frame, so a camera's center in
# board coordinates is the translation of the inverse pose.
cam_positions = [pose.inverse().translation() for pose in est_poses]
cam_x, cam_y = ([float(p[axis]) for p in cam_positions] for axis in (0, 1))

# Chessboard outline (closed polygon in the Z=0 plane).
board_w = board_cols * square_size
board_h = board_rows * square_size
board_corners = jnp.array(
    [[0, 0, 0], [board_w, 0, 0], [board_w, board_h, 0], [0, board_h, 0], [0, 0, 0]]
)

# Plot everything in millimetres.
board_trace = go.Scatter(
    x=board_corners[:, 0] * 1000,
    y=board_corners[:, 1] * 1000,
    mode="lines",
    line=dict(color="gray", width=2),
    name="Board",
)
camera_trace = go.Scatter(
    x=[c * 1000 for c in cam_x],
    y=[c * 1000 for c in cam_y],
    mode="markers+text",
    marker=dict(size=10, color="steelblue"),
    text=[str(i + 1) for i in range(num_views)],
    textposition="top center",
    name="Cameras",
)

fig_poses = go.Figure()
fig_poses.add_trace(board_trace)
fig_poses.add_trace(camera_trace)
fig_poses.update_layout(
    title="Camera Poses (top view)",
    xaxis_title="X (mm)",
    yaxis_title="Y (mm)",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=350,
    margin=dict(t=40, b=40, l=60, r=40),
    showlegend=False,
)
HTML(fig_poses.to_html(full_html=False, include_plotlyjs="cdn"))

Hide code cell source

# Single view comparison: initial vs optimized reprojection.
# Clamp so the cell still runs when fewer than two views were detected.
view_idx = min(1, num_views - 1)
obs = observations_2d[view_idx]


def _view_rmse(errors: jax.Array) -> float:
    """Root-mean-square of one view's per-corner reprojection errors, in pixels."""
    return float(jnp.sqrt(jnp.mean(errors**2)))


def _residual_segments(
    observed: jax.Array, projected: jax.Array
) -> tuple[list, list]:
    """Build x/y lists drawing every observed->projected segment in one trace.

    Consecutive segments are separated by ``None`` so Plotly breaks the line
    between them, replacing one trace per corner with a single trace.
    """
    xs: list = []
    ys: list = []
    for (ox, oy), (px, py) in zip(np.asarray(observed), np.asarray(projected)):
        xs += [float(ox), float(px), None]
        ys += [float(oy), float(py), None]
    return xs, ys


def _add_comparison_panel(
    fig: go.Figure,
    col: int,
    observed: jax.Array,
    projected: jax.Array,
    color: str,
    show_legend: bool,
) -> None:
    """Draw observed corners, projected corners and residual lines in one subplot."""
    fig.add_trace(
        go.Scatter(
            x=observed[:, 0],
            y=observed[:, 1],
            mode="markers",
            marker=dict(size=8, color="green", symbol="circle"),
            name="Observed",
            showlegend=show_legend,
        ),
        row=1,
        col=col,
    )
    fig.add_trace(
        go.Scatter(
            x=projected[:, 0],
            y=projected[:, 1],
            mode="markers",
            marker=dict(size=6, color=color, symbol="x"),
            name="Projected",
            showlegend=show_legend,
        ),
        row=1,
        col=col,
    )
    seg_x, seg_y = _residual_segments(observed, projected)
    fig.add_trace(
        go.Scatter(
            x=seg_x,
            y=seg_y,
            mode="lines",
            line=dict(color=color, width=0.5),
            showlegend=False,
            hoverinfo="skip",
        ),
        row=1,
        col=col,
    )


fig_view = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        f"Initial (RMSE={_view_rmse(init_errors[view_idx]):.2f}px)",
        f"Optimized (RMSE={_view_rmse(est_errors[view_idx]):.2f}px)",
    ),
)
# Legend entries come from the left panel only.
_add_comparison_panel(fig_view, 1, obs, init_projected[view_idx], "tomato", True)
_add_comparison_panel(fig_view, 2, obs, est_projected[view_idx], "steelblue", False)

fig_view.update_xaxes(title_text="u (pixels)")
# Image convention: v grows downward.
fig_view.update_yaxes(title_text="v (pixels)", autorange="reversed")
fig_view.update_layout(
    height=400,
    margin=dict(t=40, b=80, l=60, r=40),
    legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
)
HTML(fig_view.to_html(full_html=False, include_plotlyjs="cdn"))

Undistortion

Use the estimated distortion parameters to undistort a sample image, and visualize how many pixels each location is displaced by the lens distortion:

Hide code cell source

def undistort_image(img: np.ndarray, intrinsics: jax.Array) -> np.ndarray:
    """Undistort an image using the estimated intrinsics and scipy.ndimage.map_coordinates.

    For every pixel of the output (undistorted) image, the forward
    Brown-Conrady model gives the location in the distorted input to sample
    from. Samples use linear interpolation (order=1), and locations outside
    the input are filled with 0. The output uses the same camera matrix as the
    input.

    Args:
        img: Input image, (H, W) or (H, W, C) with any number of channels.
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Undistorted image with the same shape and dtype as the input.
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    h, w = img.shape[:2]

    # Create grid of output pixel coordinates (in undistorted image).
    u, v = np.meshgrid(np.arange(w), np.arange(h))

    # Convert to undistorted normalized coordinates.
    x_n = (u - cx) / fx
    y_n = (v - cy) / fy

    # Apply forward distortion to find where to sample from in the distorted input.
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    dx_t = 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2)
    dy_t = p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n
    x_d = x_n * radial + dx_t
    y_d = y_n * radial + dy_t

    # Convert to pixel coordinates in the distorted input image.
    u_src = fx * x_d + cx
    v_src = fy * y_d + cy
    coords = [v_src, u_src]  # map_coordinates takes (row, col) order

    if img.ndim == 2:
        return ndimage.map_coordinates(
            img, coords, order=1, mode="constant", cval=0
        )

    # Multi-channel image: resample every channel (not just the first three,
    # so e.g. an RGBA alpha channel is preserved instead of left at zero).
    undistorted = np.zeros_like(img)
    for c in range(img.shape[2]):
        undistorted[..., c] = ndimage.map_coordinates(
            img[..., c], coords, order=1, mode="constant", cval=0
        )
    return undistorted


def compute_distortion_at_points(
    intrinsics: jax.Array, points: jax.Array
) -> np.ndarray:
    """Compute distortion magnitude at specific pixel locations.

    Each pixel is treated as an undistorted location and pushed through the
    full Brown-Conrady model: radial (k1, k2) and tangential (p1, p2). The
    same model drives compute_distortion_magnitude, so these per-point values
    agree with the per-pixel distortion map.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        points: Pixel coordinates (N, 2)

    Returns:
        Distortion magnitude at each point (N,), in pixels.
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    pts = np.asarray(points, dtype=np.float64)
    u, v = pts[:, 0], pts[:, 1]

    x_n = (u - cx) / fx
    y_n = (v - cy) / fy
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    # Tangential terms are required for consistency with the full model.
    dx_t = 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2)
    dy_t = p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n
    x_d = x_n * radial + dx_t
    y_d = y_n * radial + dy_t
    u_d = fx * x_d + cx
    v_d = fy * y_d + cy

    return np.sqrt((u_d - u) ** 2 + (v_d - v) ** 2)


def compute_distortion_magnitude(
    intrinsics: jax.Array, shape: tuple[int, int]
) -> np.ndarray:
    """Per-pixel displacement, in pixels, caused by lens distortion.

    Every pixel of an image of the given shape is treated as an undistorted
    location and mapped through the Brown-Conrady model (radial k1, k2 and
    tangential p1, p2). The result is the length of the resulting shift.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        shape: Image shape (height, width)

    Returns:
        (H, W) array of displacement magnitudes in pixels.
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    h, w = shape
    rows, cols = np.mgrid[0:h, 0:w]  # v (row) and u (column) pixel grids

    # Undistorted normalized coordinates.
    xn = (cols - cx) / fx
    yn = (rows - cy) / fy

    # Distort: radial gain plus tangential offsets.
    rsq = xn**2 + yn**2
    gain = 1.0 + k1 * rsq + k2 * rsq**2
    xd = xn * gain + (2 * p1 * xn * yn + p2 * (rsq + 2 * xn**2))
    yd = yn * gain + (p1 * (rsq + 2 * yn**2) + 2 * p2 * xn * yn)

    # Back to pixels, then measure how far each pixel moved.
    du = (fx * xd + cx) - cols
    dv = (fy * yd + cy) - rows
    return np.sqrt(du**2 + dv**2)


# Show original vs undistorted + distortion map.
# Clamp so the cell still runs when fewer than two views were detected.
sample_idx = min(1, num_views - 1)
img_orig = cv2.imread(str(image_paths[valid_image_indices[sample_idx]]))
img_orig_rgb = cv2.cvtColor(img_orig, cv2.COLOR_BGR2RGB)
img_undist = undistort_image(img_orig_rgb, est_intrinsics)
distortion_map = compute_distortion_magnitude(est_intrinsics, img_orig.shape[:2])

# Compute distortion at observation locations for this view.
obs_distortion = compute_distortion_at_points(
    est_intrinsics, observations_2d[sample_idx]
)

fig_undist = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=(
        "Original",
        "Undistorted",
        f"Distortion map (corners: {obs_distortion.min():.1f}-{obs_distortion.max():.1f}px)",
    ),
)

fig_undist.add_trace(go.Image(z=img_orig_rgb), row=1, col=1)
fig_undist.add_trace(go.Image(z=img_undist), row=1, col=2)
fig_undist.add_trace(
    go.Heatmap(
        z=distortion_map,
        colorscale="Hot",
        showscale=True,
        colorbar=dict(title="px", len=0.8, x=1.02),
    ),
    row=1,
    col=3,
)
# Overlay observation locations on distortion map.
fig_undist.add_trace(
    go.Scatter(
        x=observations_2d[sample_idx][:, 0],
        y=observations_2d[sample_idx][:, 1],
        mode="markers",
        marker=dict(size=4, color="cyan", symbol="circle"),
        showlegend=False,
        hovertemplate="Distortion: %{text:.1f}px<extra></extra>",
        text=obs_distortion,
    ),
    row=1,
    col=3,
)

# Hide tick labels on all panels (previously only the heatmap's y axis).
fig_undist.update_xaxes(showticklabels=False)
fig_undist.update_yaxes(showticklabels=False)
# Flip only the heatmap so row 0 is at the top, like the image panels.
fig_undist.update_yaxes(autorange="reversed", row=1, col=3)
fig_undist.update_layout(
    height=280,
    margin=dict(t=40, b=20, l=20, r=40),
)
HTML(fig_undist.to_html(full_html=False, include_plotlyjs="cdn"))

The optimization calibrated the camera from checkerboard images:

  • Error histogram: reprojection error distribution before (red) and after (blue) optimization

  • Camera poses: top-down view of the estimated camera positions relative to the chessboard

  • Single-view comparison: observed corners (green) vs. projected corners (x markers), with lines connecting each pair

  • Undistortion: original image, undistorted image, and per-pixel distortion magnitude

For more details, see: