IMU fusion + calibration#

In this notebook, we solve a sensor fusion problem: combining IMU measurements with sparse position fixes to estimate poses, velocities, and sensor biases.

Inputs: High-rate accelerometer/gyroscope data, occasional GPS-like pose priors (position and orientation)
Outputs: Trajectory (poses + velocities) and calibrated IMU biases

Features used:

  • SE3Var for 6-DOF poses

  • Var subclassing for velocity and IMU bias variables

  • @jaxls.Cost.factory for preintegration costs

  • Batched construction for efficient problem building

Hide code cell source

import sys
from loguru import logger

# Drop loguru's existing handlers, then log to stdout with a compact
# "LEVEL | message" format.
logger.remove()
# The trailing semicolon keeps the notebook from echoing the call's return value.
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np

Custom variables#

In addition to SE3 poses, we need:

  • Velocity: 3D linear velocity in world frame

  • IMU Bias: 6D vector (3 accelerometer + 3 gyroscope biases)

class VelocityVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
    """3D linear velocity in the world frame.

    Optimization variable holding a length-3 array; its default value is the
    zero vector.
    """


class BiasVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(6)):
    """IMU bias: [accel_bias (3), gyro_bias (3)].

    Optimization variable holding a length-6 array; its default value is the
    zero vector (i.e. an unbiased sensor).
    """

Generate synthetic IMU data#

We simulate a vehicle moving along a 3D figure-eight trajectory with varying height, roll, pitch, and yaw. The ground truth trajectory gives us:

  • True poses, velocities at each keyframe

  • Synthetic IMU measurements (with noise and bias) between keyframes

# Simulation parameters.
n_keyframes = 40
imu_rate = 100  # Hz
keyframe_dt = 0.5  # seconds between keyframes
# Round before converting to int: a product such as 100 * 0.29 evaluates to
# 28.999999999999996 in floating point, and a bare int() would silently drop
# one IMU sample per interval.
imu_measurements_per_keyframe = int(round(imu_rate * keyframe_dt))
# Derive the sample period from the sample count so that the samples tile each
# keyframe interval exactly.
imu_dt = keyframe_dt / imu_measurements_per_keyframe

# True IMU biases (constant) - larger values for clear demonstration.
true_accel_bias = jnp.array([0.3, -0.2, 0.15])  # m/s^2
true_gyro_bias = jnp.array([0.02, -0.01, 0.005])  # rad/s

# Stacked as [accel (3), gyro (3)], the same layout BiasVar uses.
true_bias = jnp.concatenate([true_accel_bias, true_gyro_bias])

# Gravity in world frame.
gravity = jnp.array([0.0, 0.0, -9.81])

# Noise parameters (low to make bias estimation cleaner)
accel_noise_std = 0.01
gyro_noise_std = 0.001

print(f"Keyframes: {n_keyframes}")
print(f"Total time: {n_keyframes * keyframe_dt:.1f}s")
print(f"IMU measurements per interval: {imu_measurements_per_keyframe}")
print(f"True biases: accel={true_accel_bias}, gyro={true_gyro_bias}")
Keyframes: 40
Total time: 20.0s
IMU measurements per interval: 50
True biases: accel=[ 0.3  -0.2   0.15], gyro=[ 0.02  -0.01   0.005]
# Helper for displaying orientation variation.
def so3_to_rpy(rot: jaxlie.SO3) -> jax.Array:
    """Extract roll, pitch, yaw from SO3.

    Args:
        rot: Rotation to decompose.

    Returns:
        Array ``[roll, pitch, yaw]`` in radians, read off the rotation matrix
        entries (roll from row 2, pitch from ``R[2, 0]``, yaw from column 0).
    """
    R = rot.as_matrix()
    # Floating-point round-off can push -R[2, 0] marginally outside [-1, 1]
    # when the pitch is close to +/-90 degrees; clip so arcsin never returns NaN.
    pitch = jnp.arcsin(jnp.clip(-R[2, 0], -1.0, 1.0))
    roll = jnp.arctan2(R[2, 1], R[2, 2])
    yaw = jnp.arctan2(R[1, 0], R[0, 0])
    return jnp.array([roll, pitch, yaw])
def trajectory_state(
    t: jax.Array,
    scale: float = 5.0,
    height_amplitude: float = 3.0,
    base_height: float = 5.0,
    period: float = 10.0,
) -> tuple[jax.Array, jax.Array]:
    """Evaluate the analytic reference trajectory at time ``t``.

    The path is periodic with the given ``period``: x follows ``sin``, y follows
    ``sin * cos`` and z oscillates around ``base_height`` at twice the base
    frequency. Yaw and pitch are aligned with the direction of travel and roll
    oscillates with amplitude 0.5 rad.

    Args:
        t: Time at which to evaluate the trajectory.
        scale: Horizontal extent multiplier for x and y.
        height_amplitude: Amplitude of the vertical oscillation.
        base_height: Mean height of the path.
        period: Duration of one full cycle.

    Returns:
        ``(position, euler_angles)`` where ``position`` is ``[x, y, z]`` and
        ``euler_angles`` is ``[roll, pitch, yaw]``.
    """
    phase = 2 * jnp.pi * t / period
    phase_rate = 2 * jnp.pi / period

    sin_p, cos_p = jnp.sin(phase), jnp.cos(phase)
    sin_2p, cos_2p = jnp.sin(2 * phase), jnp.cos(2 * phase)

    position = jnp.array(
        [
            scale * sin_p,
            scale * sin_p * cos_p,
            base_height + height_amplitude * sin_2p,
        ]
    )

    # Closed-form time derivative of the position; only used to point the
    # heading (yaw) and climb angle (pitch) along the direction of motion.
    vel_x = scale * cos_p * phase_rate
    vel_y = scale * cos_2p * phase_rate
    vel_z = 2 * height_amplitude * cos_2p * phase_rate
    horizontal_speed = jnp.sqrt(vel_x**2 + vel_y**2)

    euler = jnp.array(
        [
            0.5 * sin_2p,  # roll
            jnp.arctan2(vel_z, horizontal_speed),  # pitch
            jnp.arctan2(vel_y, vel_x),  # yaw
        ]
    )
    return position, euler


def generate_trajectory_at_times(
    times: jax.Array,
    scale: float = 5.0,
    height_amplitude: float = 3.0,
    base_height: float = 5.0,
    period: float = 10.0,
) -> tuple[jaxlie.SE3, jax.Array, jax.Array, jax.Array]:
    """Sample the reference trajectory and its derivatives at the given times.

    Velocities, accelerations and Euler-angle rates are obtained by
    forward-mode autodiff of ``trajectory_state``.

    Args:
        times: Sample times; the per-time computation is vmapped over them.
        scale: Forwarded to ``trajectory_state``.
        height_amplitude: Forwarded to ``trajectory_state``.
        base_height: Forwarded to ``trajectory_state``.
        period: Forwarded to ``trajectory_state``.

    Returns:
        Poses, velocities, accelerations, and body-frame angular velocities,
        each batched over ``times``.
    """

    def sample(t: jax.Array) -> tuple[jax.Array, jax.Array]:
        return trajectory_state(t, scale, height_amplitude, base_height, period)

    # First and second time derivatives of position, first of the Euler angles.
    d_position = jax.jacfwd(lambda t: sample(t)[0])
    dd_position = jax.jacfwd(d_position)
    d_euler = jax.jacfwd(lambda t: sample(t)[1])

    def state_at(
        t: jax.Array,
    ) -> tuple[jaxlie.SE3, jax.Array, jax.Array, jax.Array]:
        position, (roll, pitch, yaw) = sample(t)
        velocity = d_position(t)
        acceleration = dd_position(t)
        roll_rate, pitch_rate, yaw_rate = d_euler(t)

        sin_roll, cos_roll = jnp.sin(roll), jnp.cos(roll)
        sin_pitch, cos_pitch = jnp.sin(pitch), jnp.cos(pitch)

        # Map Euler-angle rates to the body angular velocity (ZYX convention).
        omega_body = jnp.array(
            [
                roll_rate - yaw_rate * sin_pitch,
                pitch_rate * cos_roll + yaw_rate * cos_pitch * sin_roll,
                -pitch_rate * sin_roll + yaw_rate * cos_pitch * cos_roll,
            ]
        )

        # Orientation composed as yaw (z), then pitch (y), then roll (x).
        yaw_rot = jaxlie.SO3.from_z_radians(yaw)
        pitch_rot = jaxlie.SO3.from_y_radians(pitch)
        roll_rot = jaxlie.SO3.from_x_radians(roll)
        orientation = yaw_rot @ pitch_rot @ roll_rot

        pose = jaxlie.SE3.from_rotation_and_translation(orientation, position)
        return pose, velocity, acceleration, omega_body

    # Evaluate every requested time in one vectorized call.
    return jax.vmap(state_at)(times)


def generate_imu_measurements(
    accel_bias: jax.Array,
    gyro_bias: jax.Array,
    key: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Synthesize biased, noisy IMU readings from the analytic trajectory.

    Each sample is evaluated at the centre of its ``imu_dt`` slot. The number
    of intervals, the sample count, the timing, gravity and the noise levels
    are read from the module-level simulation parameters.

    Args:
        accel_bias: Accelerometer bias (3,)
        gyro_bias: Gyroscope bias (3,)
        key: JAX random key

    Returns:
        Tuple of (accel_measurements, gyro_measurements) each with shape
        (n_keyframes-1, imu_measurements_per_keyframe, 3)
    """
    num_intervals = n_keyframes - 1
    batched_shape = (num_intervals, imu_measurements_per_keyframe, 3)

    # Mid-slot sample times for every interval, flattened into one time axis.
    start_times = jnp.arange(num_intervals) * keyframe_dt
    slot_centres = (jnp.arange(imu_measurements_per_keyframe) + 0.5) * imu_dt
    sample_times = (start_times[:, None] + slot_centres[None, :]).flatten()

    sampled_poses, _, world_accels, body_rates = generate_trajectory_at_times(
        sample_times
    )

    def world_to_body(rotation, vector):
        return rotation.inverse() @ vector

    # The accelerometer senses specific force: R^T @ (a_world - g).
    specific_force = jax.vmap(world_to_body)(
        sampled_poses.rotation(), world_accels - gravity[None, :]
    )

    # Regroup the flat sample axis as (interval, sample, xyz). The gyroscope
    # reading is the body-frame angular velocity itself.
    accel_body = specific_force.reshape(*batched_shape)
    gyro_body = body_rates.reshape(*batched_shape)

    # Independent Gaussian noise streams for the two sensors.
    keys = jax.random.split(key, 2)
    accel_noise = jax.random.normal(keys[0], accel_body.shape) * accel_noise_std
    gyro_noise = jax.random.normal(keys[1], gyro_body.shape) * gyro_noise_std

    noisy_accel = accel_body + accel_bias[None, None, :] + accel_noise
    noisy_gyro = gyro_body + gyro_bias[None, None, :] + gyro_noise
    return noisy_accel, noisy_gyro


# Generate ground truth trajectory at keyframe times.
# Keyframes are evenly spaced, keyframe_dt apart, starting at t = 0.
keyframe_times = jnp.arange(n_keyframes) * keyframe_dt
# Only poses and velocities are kept; accelerations and angular velocities at
# the keyframes are discarded.
true_poses, true_velocities, _, _ = generate_trajectory_at_times(keyframe_times)

# Generate IMU measurements.
# Fixed seed so the synthetic noise is reproducible.
key = jax.random.PRNGKey(42)
accel_measurements, gyro_measurements = generate_imu_measurements(
    true_accel_bias, true_gyro_bias, key
)

print(f"Generated {n_keyframes} keyframe poses")
print(f"IMU measurements shape: {accel_measurements.shape}")
Generated 40 keyframe poses
IMU measurements shape: (39, 50, 3)

The generated IMU measurements include accelerometer (linear acceleration in body frame) and gyroscope (angular velocity) readings. Vertical dotted lines indicate keyframe boundaries where preintegration intervals begin/end:

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

# Visualize raw IMU measurements.
# Rebuild the sample times exactly as they were used to synthesize the data:
# every sample sits at the centre of its imu_dt slot, hence the half-sample
# offset. Without it the curves would be drawn imu_dt / 2 too early.
interval_starts = jnp.arange(n_keyframes - 1) * keyframe_dt
imu_offsets = (jnp.arange(imu_measurements_per_keyframe) + 0.5) * imu_dt
times_imu = (interval_starts[:, None] + imu_offsets[None, :]).flatten()

# Collapse (interval, sample, xyz) into one continuous time series per axis.
accel_flat = accel_measurements.reshape(-1, 3)
gyro_flat = gyro_measurements.reshape(-1, 3)

fig_imu = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=("Accelerometer Measurements", "Gyroscope Measurements"),
    vertical_spacing=0.12,
)

colors = ["steelblue", "coral", "forestgreen"]
labels = ["x", "y", "z"]
# Convert the shared time axis to NumPy once instead of once per trace.
times_np = np.array(times_imu)

# Row 1 holds the accelerometer axes, row 2 the gyroscope axes.
for row, prefix, samples in ((1, "accel", accel_flat), (2, "gyro", gyro_flat)):
    for i, (color, label) in enumerate(zip(colors, labels)):
        fig_imu.add_trace(
            go.Scatter(
                x=times_np,
                y=np.array(samples[:, i]),
                mode="lines",
                line=dict(color=color, width=1),
                name=f"{prefix}_{label}",
                showlegend=True,
            ),
            row=row,
            col=1,
        )

# Add keyframe markers (NumPy values so Plotly receives plain floats).
keyframe_times = np.arange(n_keyframes) * keyframe_dt
for t in keyframe_times:
    for row in (1, 2):
        fig_imu.add_vline(
            x=t, line_dash="dot", line_color="gray", opacity=0.3, row=row, col=1
        )

fig_imu.update_xaxes(title_text="Time (s)", row=2, col=1)
fig_imu.update_yaxes(title_text="Acceleration (m/s²)", row=1, col=1)
fig_imu.update_yaxes(title_text="Angular velocity (rad/s)", row=2, col=1)

fig_imu.update_layout(
    height=400,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5),
)

HTML(fig_imu.to_html(full_html=False, include_plotlyjs="cdn"))

IMU preintegration#

Preintegration combines many high-frequency IMU measurements into a single relative motion constraint. Given IMU measurements between times \(i\) and \(j\), we compute:

  • \(\Delta R_{ij}\): Relative rotation

  • \(\Delta v_{ij}\): Velocity change (in frame \(i\))

  • \(\Delta p_{ij}\): Position change (in frame \(i\))

These preintegrated measurements are independent of the absolute pose/velocity at time \(i\), which allows efficient re-linearization during optimization.

def preintegrate_imu(
    accel_meas: jax.Array,  # (n_measurements, 3)
    gyro_meas: jax.Array,  # (n_measurements, 3)
    accel_bias: jax.Array,  # (3,)
    gyro_bias: jax.Array,  # (3,)
    dt: float,
) -> tuple[jaxlie.SO3, jax.Array, jax.Array]:
    """Accumulate a run of IMU samples into one relative-motion increment.

    The biases are subtracted from the raw samples, which are then integrated
    one at a time starting from the identity rotation and zero velocity and
    position increments. Gravity is not applied here.

    Args:
        accel_meas: Accelerometer measurements (n_measurements, 3)
        gyro_meas: Gyroscope measurements (n_measurements, 3)
        accel_bias: Accelerometer bias (3,)
        gyro_bias: Gyroscope bias (3,)
        dt: Time step between measurements

    Returns:
        Tuple of (delta_R, delta_v, delta_p):
            delta_R: Relative rotation (SO3)
            delta_v: Velocity increment, expressed in the frame at the start
                of the run (3,)
            delta_p: Position increment, expressed in the frame at the start
                of the run (3,)
    """
    unbiased_accel = accel_meas - accel_bias[None, :]
    unbiased_gyro = gyro_meas - gyro_bias[None, :]

    def integrate_sample(
        state: tuple[jaxlie.SO3, jax.Array, jax.Array],
        sample: tuple[jax.Array, jax.Array],
    ) -> tuple[tuple[jaxlie.SO3, jax.Array, jax.Array], None]:
        rot, vel, pos = state
        accel_sample, gyro_sample = sample

        # Express the acceleration in the starting frame using the rotation
        # accumulated *before* this sample.
        accel_in_start = rot @ accel_sample

        # Constant-acceleration update; the position uses the velocity from
        # before this sample.
        next_pos = pos + vel * dt + 0.5 * accel_in_start * dt**2
        next_vel = vel + accel_in_start * dt

        # Compose the incremental rotation on the right: R <- R * Exp(w * dt).
        next_rot = rot @ jaxlie.SO3.exp(gyro_sample * dt)

        return (next_rot, next_vel, next_pos), None

    start_state = (jaxlie.SO3.identity(), jnp.zeros(3), jnp.zeros(3))
    final_state, _ = jax.lax.scan(
        integrate_sample, start_state, (unbiased_accel, unbiased_gyro)
    )
    delta_R, delta_v, delta_p = final_state
    return delta_R, delta_v, delta_p


# Test preintegration with true bias.
# Only the first keyframe interval is integrated here, as a quick sanity check.
delta_R_test, delta_v_test, delta_p_test = preintegrate_imu(
    accel_measurements[0], gyro_measurements[0], true_accel_bias, true_gyro_bias, imu_dt
)
print("Test preintegration:")
# The rotation increment is printed as its log (tangent-vector) form.
print(f"  delta_R: {delta_R_test.log()}")
print(f"  delta_v: {delta_v_test}")
print(f"  delta_p: {delta_p_test}")
Test preintegration:
  delta_R: [ 0.34444118 -0.05179568 -0.05575576]
  delta_v: [-3.1116118  -0.30711094  2.8483238 ]
  delta_p: [-0.78262115 -0.05092433  0.784495  ]

Cost functions#

We define three types of costs:

  1. IMU preintegration cost: Constrains consecutive pose/velocity pairs based on integrated IMU measurements

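  2. Prior costs: Anchor the poses at the sparse GPS-like keyframes and the velocity at the first keyframe

  3. Bias prior cost: Weakly pulls the single, constant bias estimate toward zero, leaving it free to be determined by the IMU residuals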

@jaxls.Cost.factory
def imu_cost(
    vals: jaxls.VarValues,
    pose_i: jaxls.SE3Var,
    vel_i: VelocityVar,
    pose_j: jaxls.SE3Var,
    vel_j: VelocityVar,
    bias: BiasVar,
    accel_meas: jax.Array,
    gyro_meas: jax.Array,
    dt_total: float,
    dt_imu: float,
    gravity: jax.Array,
) -> jax.Array:
    """Residual tying two consecutive keyframes together through the IMU.

    The raw measurements are preintegrated with the current bias estimate, the
    state at keyframe ``i`` is propagated to keyframe ``j``, and the mismatch
    with the estimated state at ``j`` is returned.

    Args:
        vals: Current variable values.
        pose_i: Pose variable at the start of the interval.
        vel_i: Velocity variable at the start of the interval.
        pose_j: Pose variable at the end of the interval.
        vel_j: Velocity variable at the end of the interval.
        bias: IMU bias variable, laid out as [accel (3), gyro (3)].
        accel_meas: Accelerometer samples within the interval.
        gyro_meas: Gyroscope samples within the interval.
        dt_total: Duration of the whole interval.
        dt_imu: Time step between consecutive IMU samples.
        gravity: Gravity vector in the world frame.

    Returns:
        Concatenated weighted residual ``[position (3), velocity (3),
        rotation (3)]``.
    """
    start_pose, end_pose = vals[pose_i], vals[pose_j]
    start_vel, end_vel = vals[vel_i], vals[vel_j]
    bias_estimate = vals[bias]

    start_rot = start_pose.rotation()

    # Re-integrate the raw samples using the bias currently being estimated.
    delta_R, delta_v, delta_p = preintegrate_imu(
        accel_meas, gyro_meas, bias_estimate[:3], bias_estimate[3:], dt_imu
    )

    # Propagate keyframe i forward; gravity enters here, not in the
    # preintegrated increments.
    predicted_position = (
        start_pose.translation()
        + start_vel * dt_total
        + 0.5 * gravity * dt_total**2
        + start_rot @ delta_p
    )
    predicted_velocity = start_vel + gravity * dt_total + start_rot @ delta_v
    predicted_rotation = start_rot @ delta_R

    # Weighted mismatches. The rotation term gets the largest weight since the
    # gyro is treated as the more accurate sensor.
    position_residual = (end_pose.translation() - predicted_position) * 10.0
    velocity_residual = (end_vel - predicted_velocity) * 10.0
    rotation_residual = (
        predicted_rotation.inverse() @ end_pose.rotation()
    ).log() * 100.0

    return jnp.concatenate(
        [position_residual, velocity_residual, rotation_residual]
    )


@jaxls.Cost.factory
def pose_prior_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jaxlie.SE3,
) -> jax.Array:
    """Penalize deviation of an SE3 pose from a target pose.

    Args:
        vals: Current variable values.
        var: Pose variable being constrained.
        target: Pose the variable is pulled toward.

    Returns:
        Weighted 6-vector residual: translation part (meters) scaled by 50,
        rotation part (radians) scaled by 100.
    """
    tangent_error = (vals[var].inverse() @ target).log()
    translation_error = tangent_error[:3]
    rotation_error = tangent_error[3:]
    return jnp.concatenate([translation_error * 50.0, rotation_error * 100.0])


@jaxls.Cost.factory
def velocity_prior_cost(
    vals: jaxls.VarValues,
    var: VelocityVar,
    target: jax.Array,
) -> jax.Array:
    """Penalize deviation of a velocity variable from a target velocity.

    Returns the difference between the current value and ``target``, scaled
    by a weight of 50.
    """
    velocity_error = vals[var] - target
    return velocity_error * 50.0


@jaxls.Cost.factory
def bias_prior_cost(
    vals: jaxls.VarValues,
    var: BiasVar,
    target: jax.Array,
) -> jax.Array:
    """Weakly pull the IMU bias toward a target value.

    The weight (0.1) is deliberately small so that the bias is determined by
    the IMU residuals rather than by this prior.
    """
    bias_error = vals[var] - target
    return bias_error * 0.1

Solving#

# Create variables.
# One pose and one velocity per keyframe, created in batch from an id array.
pose_vars = jaxls.SE3Var(id=jnp.arange(n_keyframes))
vel_vars = VelocityVar(id=jnp.arange(n_keyframes))
bias_var = BiasVar(id=0)  # Single bias variable (assumed constant)

# Sparse GPS-like pose priors (e.g., from occasional GPS fixes). Each prior
# uses the full ground-truth pose, i.e. position and orientation.
# These provide just enough constraint for bias observability.
gps_interval = 10  # GPS fix every 10 keyframes (5 seconds)
gps_indices = jnp.arange(0, n_keyframes, gps_interval)
gps_poses = jaxlie.SE3(wxyz_xyz=true_poses.wxyz_xyz[gps_indices])

# IMU cost indices (consecutive keyframe pairs)
n_intervals = n_keyframes - 1
imu_i_ids = jnp.arange(n_intervals)
imu_j_ids = jnp.arange(1, n_keyframes)

# Tile the gravity vector to one row per IMU interval for batching.
gravity_batched = jnp.tile(gravity[None, :], (n_intervals, 1))

# Build costs using batched construction.
costs: list[jaxls.Cost] = [
    # Sparse GPS pose priors.
    pose_prior_cost(jaxls.SE3Var(id=gps_indices), gps_poses),
    # Anchor start velocity.
    velocity_prior_cost(VelocityVar(id=0), true_velocities[0]),
    # Weak prior on bias (centered at zero)
    bias_prior_cost(bias_var, jnp.zeros(6)),
    # IMU preintegration costs between consecutive keyframes (batched)
    imu_cost(
        jaxls.SE3Var(id=imu_i_ids),
        VelocityVar(id=imu_i_ids),
        jaxls.SE3Var(id=imu_j_ids),
        VelocityVar(id=imu_j_ids),
        bias_var,
        accel_measurements,
        gyro_measurements,
        keyframe_dt,
        imu_dt,
        gravity_batched,
    ),
]

print(f"Created {len(costs)} batched cost objects")
print(f"GPS priors at keyframes: {list(gps_indices)} (every {gps_interval} keyframes)")
Created 4 batched cost objects
GPS priors at keyframes: [Array(0, dtype=int32), Array(10, dtype=int32), Array(20, dtype=int32), Array(30, dtype=int32)] (every 10 keyframes)
# Initial values: first pose/velocity from prior, rest from dead reckoning.
# We start with zero bias estimate.


@jax.jit
def dead_reckon_trajectory(
    initial_pose: jaxlie.SE3,
    initial_vel: jax.Array,
    accel_meas: jax.Array,
    gyro_meas: jax.Array,
) -> tuple[jaxlie.SE3, jax.Array]:
    """Integrate the IMU forward from an initial state, assuming zero bias.

    Each keyframe interval is preintegrated and chained onto the previous
    state with ``jax.lax.scan``. Sample period, interval duration and gravity
    come from the module-level simulation parameters.

    Args:
        initial_pose: Starting pose
        initial_vel: Starting velocity (3,)
        accel_meas: Accelerometer measurements (n_intervals, n_measurements, 3)
        gyro_meas: Gyroscope measurements (n_intervals, n_measurements, 3)

    Returns:
        Tuple of (poses, velocities) for all keyframes, with the initial state
        as the first entry.
    """
    no_bias = jnp.zeros(3)

    def propagate(
        state: tuple[jaxlie.SE3, jax.Array],
        interval_meas: tuple[jax.Array, jax.Array],
    ) -> tuple[tuple[jaxlie.SE3, jax.Array], tuple[jax.Array, jax.Array]]:
        current_pose, current_vel = state
        interval_accel, interval_gyro = interval_meas

        # Relative motion over this interval, with both biases taken as zero.
        delta_R, delta_v, delta_p = preintegrate_imu(
            interval_accel, interval_gyro, no_bias, no_bias, imu_dt
        )

        rot = current_pose.rotation()

        # Chain the increment onto the current state, adding gravity back in.
        next_position = (
            current_pose.translation()
            + current_vel * keyframe_dt
            + 0.5 * gravity * keyframe_dt**2
            + rot @ delta_p
        )
        next_vel = current_vel + gravity * keyframe_dt + rot @ delta_v
        propagated = jaxlie.SE3.from_rotation_and_translation(
            rot @ delta_R, next_position
        )

        # Emit the raw pose parameters so scan can stack them into an array.
        return (propagated, next_vel), (propagated.wxyz_xyz, next_vel)

    _, (later_pose_params, later_vels) = jax.lax.scan(
        propagate, (initial_pose, initial_vel), (accel_meas, gyro_meas)
    )

    # The scan only yields the propagated states; put the initial one in front.
    pose_params = jnp.concatenate([initial_pose.wxyz_xyz[None], later_pose_params])
    velocities = jnp.concatenate([initial_vel[None], later_vels])
    return jaxlie.SE3(wxyz_xyz=pose_params), velocities


# Get initial pose and velocity from ground truth.
initial_pose = jaxlie.SE3(wxyz_xyz=true_poses.wxyz_xyz[0])
initial_vel = true_velocities[0]

# Dead reckon through every interval; the result seeds the optimizer.
initial_poses, initial_velocities = dead_reckon_trajectory(
    initial_pose, initial_vel, accel_measurements, gyro_measurements
)

# Compare the final dead-reckoned position with the true one to show the drift.
print(f"Initial dead-reckoned end position: {initial_poses.translation()[-1]}")
print(f"True end position: {true_poses.translation()[-1]}")
Initial dead-reckoned end position: [-14.247985  10.484206  19.88098 ]
True end position: [-1.545084  -1.4694623  3.2366452]
# Create initial values.
# Poses and velocities start from dead reckoning; the bias starts at zero.
initial_vals = jaxls.VarValues.make(
    [
        pose_vars.with_value(initial_poses),
        vel_vars.with_value(initial_velocities),
        bias_var.with_value(jnp.zeros(6)),
    ]
)

# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars, vel_vars, bias_var])

# Visualize the problem structure.
problem.show()
# Analyze and solve.
problem = problem.analyze()

solution = problem.solve(
    initial_vals,
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)
INFO     | Building optimization problem with 45 terms and 81 variables: 45 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 4 costs, 1 variables each: pose_prior_cost
INFO     | Vectorizing group with 1 costs, 1 variables each: velocity_prior_cost
INFO     | Vectorizing group with 39 costs, 5 variables each: imu_cost
INFO     | Vectorizing group with 1 costs, 1 variables each: bias_prior_cost
INFO     |  step #0: cost=526848.7500 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - pose_prior_cost(4): 526848.75000 (avg 21952.03125)
INFO     |      - velocity_prior_cost(1): 0.00000 (avg 0.00000)
INFO     |      - imu_cost(39):   0.00000 (avg 0.00000)
INFO     |      - bias_prior_cost(1): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.70e+04 cost_prev=526848.7500 cost_new=350.6612
INFO     |  step #1: cost=350.6612 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - pose_prior_cost(4): 23.79979 (avg 0.99166)
INFO     |      - velocity_prior_cost(1): 0.06903 (avg 0.02301)
INFO     |      - imu_cost(39):   326.78799 (avg 0.93102)
INFO     |      - bias_prior_cost(1): 0.00440 (avg 0.00073)
INFO     |      accepted=True ATb_norm=3.32e+02 cost_prev=350.6612 cost_new=10.8448
INFO     |  step #2: cost=10.8448 lambd=0.0001 inexact_tol=7.2e-05
INFO     |      - pose_prior_cost(4): 0.00997 (avg 0.00042)
INFO     |      - velocity_prior_cost(1): 0.00021 (avg 0.00007)
INFO     |      - imu_cost(39):   10.83304 (avg 0.03086)
INFO     |      - bias_prior_cost(1): 0.00153 (avg 0.00025)
INFO     |      accepted=True ATb_norm=6.96e+01 cost_prev=10.8448 cost_new=0.0175
INFO     |  step #3: cost=0.0175 lambd=0.0001 inexact_tol=7.2e-05
INFO     |      - pose_prior_cost(4): 0.00423 (avg 0.00018)
INFO     |      - velocity_prior_cost(1): 0.00001 (avg 0.00000)
INFO     |      - imu_cost(39):   0.01169 (avg 0.00003)
INFO     |      - bias_prior_cost(1): 0.00156 (avg 0.00026)
INFO     |      accepted=True ATb_norm=1.77e-01 cost_prev=0.0175 cost_new=0.0174
INFO     |  step #4: cost=0.0174 lambd=0.0000 inexact_tol=5.8e-06
INFO     |      - pose_prior_cost(4): 0.00423 (avg 0.00018)
INFO     |      - velocity_prior_cost(1): 0.00001 (avg 0.00000)
INFO     |      - imu_cost(39):   0.01163 (avg 0.00003)
INFO     |      - bias_prior_cost(1): 0.00156 (avg 0.00026)
INFO     |      accepted=False ATb_norm=1.97e-02 cost_prev=0.0174 cost_new=0.0174
INFO     | Terminated @ iteration #5: cost=0.0174 criteria=[0 0 1], term_deltas=6.2e-05,2.9e-03,3.8e-07
# Extract results.
estimated_poses = solution[pose_vars]
estimated_velocities = solution[vel_vars]
estimated_bias = solution[bias_var]

# The bias vector is laid out as [accel (3), gyro (3)]; report absolute
# per-axis errors against the simulated ground truth.
print("Bias estimation:")
print(f"  True:      accel={true_accel_bias}, gyro={true_gyro_bias}")
print(f"  Estimated: accel={estimated_bias[:3]}, gyro={estimated_bias[3:]}")
print(
    f"  Error:     accel={jnp.abs(estimated_bias[:3] - true_accel_bias)}, "
    f"gyro={jnp.abs(estimated_bias[3:] - true_gyro_bias)}"
)
Bias estimation:
  True:      accel=[ 0.3  -0.2   0.15], gyro=[ 0.02  -0.01   0.005]
  Estimated: accel=[ 0.31533828 -0.18791492  0.14483762], gyro=[ 0.02158698 -0.00830974  0.00490012]
  Error:     accel=[0.01533827 0.01208508 0.00516239], gyro=[1.5869755e-03 1.6902629e-03 9.9882483e-05]

Visualization#

Compare the estimated trajectory with ground truth and the dead-reckoned initialization:

# Extract positions.
# Converted to NumPy arrays for plotting.
true_positions = np.array(true_poses.translation())
initial_positions = np.array(initial_poses.translation())
estimated_positions = np.array(estimated_poses.translation())

# Compute errors.
# Euclidean distance between estimated and true position at each keyframe.
position_errors = jnp.linalg.norm(
    estimated_poses.translation() - true_poses.translation(), axis=-1
)
print(
    f"Position errors: mean={float(jnp.mean(position_errors)):.4f}m, "
    f"max={float(jnp.max(position_errors)):.4f}m"
)
Position errors: mean=0.0406m, max=0.1023m

Hide code cell source

fig = go.Figure()

# One entry per trajectory, in drawing order:
# (positions, legend name, line style, marker style, hover template).
trajectory_traces = [
    (
        true_positions,
        "Ground truth",
        dict(color="forestgreen", width=4),
        dict(size=4, color="forestgreen"),
        "GT: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>",
    ),
    (
        initial_positions,
        "Dead reckoning (init)",
        dict(color="tomato", width=2, dash="dash"),
        dict(size=3, color="tomato"),
        "Init: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>",
    ),
    (
        estimated_positions,
        "Optimized",
        dict(color="steelblue", width=4),
        dict(size=5, color="steelblue"),
        "Opt: (%{x:.2f}, %{y:.2f}, %{z:.2f})<extra></extra>",
    ),
]

for points, trace_name, line_style, marker_style, hover in trajectory_traces:
    fig.add_trace(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode="lines+markers",
            line=line_style,
            marker=marker_style,
            name=trace_name,
            hovertemplate=hover,
        )
    )

# Highlight where the trajectory begins (first ground-truth position).
start_point = true_positions[0]
fig.add_trace(
    go.Scatter3d(
        x=[start_point[0]],
        y=[start_point[1]],
        z=[start_point[2]],
        mode="markers",
        marker=dict(size=10, color="green", symbol="diamond"),
        name="Start",
        hovertemplate="Start<extra></extra>",
    )
)


def _axis(title):
    """Axis settings shared by all three scene axes."""
    return dict(title=title, showbackground=False)


# Layout.
fig.update_layout(
    scene=dict(
        xaxis=_axis("X (m)"),
        yaxis=_axis("Y (m)"),
        zaxis=_axis("Z (m)"),
        aspectmode="data",
        camera=dict(eye=dict(x=1.5, y=1.5, z=1.0)),
    ),
    height=550,
    margin=dict(t=30, b=20, l=20, r=20),
    legend=dict(x=0.02, y=0.98, bgcolor="rgba(255,255,255,0.8)"),
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Key observations#

The optimizer jointly estimates poses, velocities, and IMU biases by minimizing residuals between predicted and measured relative motion:

  • Bias calibration: Accelerometer bias becomes observable because gravity projects differently onto accelerometer axes at varied orientations. With roll and pitch changes throughout the trajectory, the optimizer can separate true acceleration from constant bias.

  • Trajectory correction: Dead reckoning (red dashed) drifts due to the biased IMU measurements. The optimizer corrects this drift to match the ground truth (green).

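  • Preintegration efficiency: Each IMU cost condenses the 50 raw measurements of an interval into a single relative motion constraint between two keyframes. Note that in this example the preintegration is re-run inside the cost with the current bias estimate, so bias changes are handled exactly; production systems typically preintegrate once and apply a first-order bias correction instead.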