IMU fusion + calibration
In this notebook, we solve a sensor fusion problem: combining IMU measurements with sparse position fixes to estimate poses, velocities, and sensor biases.
Inputs: High-rate accelerometer/gyroscope data, occasional GPS-like pose priors (position and orientation)
Outputs: Trajectory (poses + velocities) and calibrated IMU biases
Features used:
SE3Var for 6-DOF poses
Var subclassing for velocity and IMU bias variables
@jaxls.Cost.factory for preintegration costs
Batched construction for efficient problem building
import jax
import jax.numpy as jnp
import jaxlie
import jaxls
import numpy as np
Custom variables
In addition to SE3 poses, we need:
Velocity: 3D linear velocity in world frame
IMU Bias: 6D vector (3 accelerometer + 3 gyroscope biases)
class VelocityVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))):
    """Linear velocity of the body, expressed in the world frame.

    Values are length-3 vectors; the default is zero velocity.
    """
class BiasVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((6,))):
    """IMU bias vector laid out as [accel_bias (3), gyro_bias (3)].

    Values are length-6 vectors; the default is zero bias.
    """
Generate synthetic IMU data
We simulate a vehicle moving along a figure-eight trajectory with oscillating height and roll. The ground truth trajectory gives us:
True poses, velocities at each keyframe
Synthetic IMU measurements (with noise and bias) between keyframes
# Simulation parameters.
n_keyframes = 40
imu_rate = 100  # Hz
keyframe_dt = 0.5  # seconds between keyframes
# round() rather than int(): int() truncates, so a product that lands just below
# an integer in floating point (e.g. 100 * 0.29 == 28.999999999999996) would
# silently lose one IMU sample per keyframe interval.
imu_measurements_per_keyframe = round(imu_rate * keyframe_dt)
if imu_measurements_per_keyframe < 1:
    raise ValueError(
        "imu_rate * keyframe_dt must give at least one IMU sample per keyframe interval"
    )
imu_dt = keyframe_dt / imu_measurements_per_keyframe
# True IMU biases (constant) - larger values for clear demonstration.
true_accel_bias = jnp.array([0.3, -0.2, 0.15])  # m/s^2
true_gyro_bias = jnp.array([0.02, -0.01, 0.005])  # rad/s
# Stacked as [accel (3), gyro (3)], the same layout BiasVar uses.
true_bias = jnp.concatenate([true_accel_bias, true_gyro_bias])
# Gravity in world frame.
gravity = jnp.array([0.0, 0.0, -9.81])
# Noise parameters (low to make bias estimation cleaner).
accel_noise_std = 0.01
gyro_noise_std = 0.001
print(f"Keyframes: {n_keyframes}")
print(f"Total time: {n_keyframes * keyframe_dt:.1f}s")
print(f"IMU measurements per interval: {imu_measurements_per_keyframe}")
print(f"True biases: accel={true_accel_bias}, gyro={true_gyro_bias}")
Keyframes: 40
Total time: 20.0s
IMU measurements per interval: 50
True biases: accel=[ 0.3 -0.2 0.15], gyro=[ 0.02 -0.01 0.005]
# Helper for displaying orientation variation.
def so3_to_rpy(rot: jaxlie.SO3) -> jax.Array:
    """Extract roll, pitch, yaw from an SO3 using the ZYX convention.

    Inverts ``SO3.from_z_radians(yaw) @ SO3.from_y_radians(pitch) @
    SO3.from_x_radians(roll)`` for |pitch| < pi/2. At gimbal lock
    (|pitch| = pi/2), roll and yaw are not uniquely determined.

    Args:
        rot: Rotation. It may be batched, in which case the angles are
            computed independently for each rotation.

    Returns:
        Array of shape ``(..., 3)`` holding ``[roll, pitch, yaw]`` in radians.
    """
    R = rot.as_matrix()
    # Floating-point rounding can push |R[2, 0]| slightly above 1, where arcsin
    # returns NaN; clip back into arcsin's domain.
    pitch = jnp.arcsin(jnp.clip(-R[..., 2, 0], -1.0, 1.0))
    roll = jnp.arctan2(R[..., 2, 1], R[..., 2, 2])
    yaw = jnp.arctan2(R[..., 1, 0], R[..., 0, 0])
    return jnp.stack([roll, pitch, yaw], axis=-1)
def trajectory_state(
    t: jax.Array,
    scale: float = 5.0,
    height_amplitude: float = 3.0,
    base_height: float = 5.0,
    period: float = 10.0,
) -> tuple[jax.Array, jax.Array]:
    """Evaluate the analytic reference trajectory at time ``t``.

    The horizontal path is a figure-eight (x = s*sin(phase),
    y = s*sin(phase)*cos(phase)), and the height oscillates at twice the phase
    rate. Yaw and pitch point along the analytic velocity (heading and climb
    angle). Roll is a prescribed sinusoid.

    Returns:
        ``(position, euler_angles)`` with ``euler_angles = [roll, pitch, yaw]``.
    """
    phase = 2 * jnp.pi * t / period
    phase_rate = 2 * jnp.pi / period
    sin_phase, cos_phase = jnp.sin(phase), jnp.cos(phase)
    sin_double, cos_double = jnp.sin(2 * phase), jnp.cos(2 * phase)

    position = jnp.array(
        [
            scale * sin_phase,
            scale * sin_phase * cos_phase,
            base_height + height_amplitude * sin_double,
        ]
    )

    # Analytic velocity components. Only the direction of travel is used here,
    # to orient the body.
    vel_x = scale * cos_phase * phase_rate
    vel_y = scale * cos_double * phase_rate
    vel_z = 2 * height_amplitude * cos_double * phase_rate

    heading = jnp.arctan2(vel_y, vel_x)
    climb = jnp.arctan2(vel_z, jnp.sqrt(vel_x**2 + vel_y**2))
    bank = 0.5 * sin_double
    return position, jnp.array([bank, climb, heading])
def generate_trajectory_at_times(
    times: jax.Array,
    scale: float = 5.0,
    height_amplitude: float = 3.0,
    base_height: float = 5.0,
    period: float = 10.0,
) -> tuple[jaxlie.SE3, jax.Array, jax.Array, jax.Array]:
    """Sample the reference trajectory and its derivatives at each of ``times``.

    Time derivatives are obtained by forward-mode autodiff of
    ``trajectory_state`` with respect to the scalar time.

    Returns:
        ``(poses, velocities, accelerations, angular_velocities)``, batched over
        ``times``. Velocities and accelerations are derivatives of the world
        position. Angular velocities are expressed in the body frame.
    """

    def position_of(t: jax.Array) -> jax.Array:
        return trajectory_state(t, scale, height_amplitude, base_height, period)[0]

    def euler_of(t: jax.Array) -> jax.Array:
        return trajectory_state(t, scale, height_amplitude, base_height, period)[1]

    d_position = jax.jacfwd(position_of)
    dd_position = jax.jacfwd(d_position)
    d_euler = jax.jacfwd(euler_of)

    def single_sample(
        t: jax.Array,
    ) -> tuple[jaxlie.SE3, jax.Array, jax.Array, jax.Array]:
        position, (roll, pitch, yaw) = trajectory_state(
            t, scale, height_amplitude, base_height, period
        )
        roll_rate, pitch_rate, yaw_rate = d_euler(t)

        sin_roll, cos_roll = jnp.sin(roll), jnp.cos(roll)
        cos_pitch = jnp.cos(pitch)
        # Map ZYX Euler-angle rates to the body-frame angular velocity.
        omega_body = jnp.array(
            [
                roll_rate - yaw_rate * jnp.sin(pitch),
                pitch_rate * cos_roll + yaw_rate * cos_pitch * sin_roll,
                -pitch_rate * sin_roll + yaw_rate * cos_pitch * cos_roll,
            ]
        )

        # Orientation = Rz(yaw) @ Ry(pitch) @ Rx(roll).
        orientation = (
            jaxlie.SO3.from_z_radians(yaw)
            @ jaxlie.SO3.from_y_radians(pitch)
            @ jaxlie.SO3.from_x_radians(roll)
        )
        pose = jaxlie.SE3.from_rotation_and_translation(orientation, position)
        return pose, d_position(t), dd_position(t), omega_body

    return jax.vmap(single_sample)(times)
def generate_imu_measurements(
    accel_bias: jax.Array,
    gyro_bias: jax.Array,
    key: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Synthesize biased, noisy IMU readings from the analytic trajectory.

    One sample is taken at the midpoint of every IMU period inside each of the
    ``n_keyframes - 1`` keyframe intervals. Ideal readings come from
    ``generate_trajectory_at_times``. The given constant biases and zero-mean
    Gaussian noise (scaled by ``accel_noise_std`` / ``gyro_noise_std``) are
    then added.

    Args:
        accel_bias: Accelerometer bias (3,)
        gyro_bias: Gyroscope bias (3,)
        key: JAX random key for the measurement noise

    Returns:
        Tuple of (accel_measurements, gyro_measurements) each with shape
        (n_keyframes-1, imu_measurements_per_keyframe, 3)
    """
    n_intervals = n_keyframes - 1
    per_interval_shape = (n_intervals, imu_measurements_per_keyframe, 3)

    # Midpoint sample times, laid out interval-major so that a reshape recovers
    # the (interval, sample) structure.
    start_times = jnp.arange(n_intervals) * keyframe_dt
    midpoint_offsets = (jnp.arange(imu_measurements_per_keyframe) + 0.5) * imu_dt
    sample_times = (start_times[:, None] + midpoint_offsets[None, :]).reshape(-1)

    sample_poses, _, world_accels, body_rates = generate_trajectory_at_times(
        sample_times
    )

    # The accelerometer senses specific force (acceleration minus gravity),
    # expressed in the body frame: R^T @ (a_world - g).
    def world_to_body(rotation: jaxlie.SO3, vec: jax.Array) -> jax.Array:
        return rotation.inverse() @ vec

    ideal_accel = jax.vmap(world_to_body)(
        sample_poses.rotation(), world_accels - gravity[None, :]
    ).reshape(per_interval_shape)
    # The gyroscope senses body-frame angular velocity directly.
    ideal_gyro = body_rates.reshape(per_interval_shape)

    accel_key, gyro_key = jax.random.split(key, 2)
    accel_noise = jax.random.normal(accel_key, per_interval_shape) * accel_noise_std
    gyro_noise = jax.random.normal(gyro_key, per_interval_shape) * gyro_noise_std

    accels = ideal_accel + accel_bias[None, None, :] + accel_noise
    gyros = ideal_gyro + gyro_bias[None, None, :] + gyro_noise
    return accels, gyros
# Generate ground truth trajectory at keyframe times.
keyframe_times = jnp.arange(n_keyframes) * keyframe_dt
# Only poses and velocities are needed at keyframes; the accelerations and
# body angular velocities returned alongside them are discarded.
true_poses, true_velocities, _, _ = generate_trajectory_at_times(keyframe_times)
# Generate IMU measurements.
# Fixed PRNG seed so the synthetic noise (and the printed results) are reproducible.
key = jax.random.PRNGKey(42)
# The true biases are injected here; the optimizer later has to recover them.
accel_measurements, gyro_measurements = generate_imu_measurements(
    true_accel_bias, true_gyro_bias, key
)
print(f"Generated {n_keyframes} keyframe poses")
print(f"IMU measurements shape: {accel_measurements.shape}")
Generated 40 keyframe poses
IMU measurements shape: (39, 50, 3)
The generated IMU measurements include accelerometer readings (specific force in the body frame, i.e. acceleration minus gravity) and gyroscope readings (body-frame angular velocity), each corrupted by a constant bias and Gaussian noise. Vertical dotted lines indicate keyframe boundaries where preintegration intervals begin/end:
IMU preintegration
Preintegration combines many high-frequency IMU measurements into a single relative motion constraint. Given IMU measurements between times \(i\) and \(j\), we compute:
\(\Delta R_{ij}\): Relative rotation
\(\Delta v_{ij}\): Velocity change (in frame \(i\))
\(\Delta p_{ij}\): Position change (in frame \(i\))
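These preintegrated measurements are independent of the absolute pose/velocity at time \(i\), which allows efficient re-linearization during optimization. They do depend on the IMU bias, which is why the IMU cost below re-runs preintegration with the current bias estimate.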
def preintegrate_imu(
    accel_meas: jax.Array,  # (n_measurements, 3)
    gyro_meas: jax.Array,  # (n_measurements, 3)
    accel_bias: jax.Array,  # (3,)
    gyro_bias: jax.Array,  # (3,)
    dt: float,
) -> tuple[jaxlie.SO3, jax.Array, jax.Array]:
    """Fold a run of IMU samples into one relative-motion measurement.

    The biases are subtracted from every sample. The samples are then
    integrated in order, starting from the identity rotation and zero
    velocity/position. Each step's acceleration is rotated into the starting
    (keyframe-i) frame using the rotation accumulated *before* that step.
    Gravity is not included.

    Args:
        accel_meas: Accelerometer measurements (n_measurements, 3)
        gyro_meas: Gyroscope measurements (n_measurements, 3)
        accel_bias: Accelerometer bias (3,)
        gyro_bias: Gyroscope bias (3,)
        dt: Time step between measurements

    Returns:
        Tuple of (delta_R, delta_v, delta_p):
            delta_R: Relative rotation (SO3)
            delta_v: Velocity increment in the frame at the start of the run (3,)
            delta_p: Position increment in the frame at the start of the run (3,)
    """

    def integrate_one(
        state: tuple[jaxlie.SO3, jax.Array, jax.Array],
        sample: tuple[jax.Array, jax.Array],
    ) -> tuple[tuple[jaxlie.SO3, jax.Array, jax.Array], None]:
        rot, vel, pos = state
        accel_sample, gyro_sample = sample
        accel_in_start_frame = rot @ accel_sample
        # Position uses the velocity from before this step (constant-accel model).
        pos_next = pos + vel * dt + 0.5 * accel_in_start_frame * dt**2
        vel_next = vel + accel_in_start_frame * dt
        rot_next = rot @ jaxlie.SO3.exp(gyro_sample * dt)
        return (rot_next, vel_next, pos_next), None

    start = (jaxlie.SO3.identity(), jnp.zeros(3), jnp.zeros(3))
    unbiased = (accel_meas - accel_bias[None, :], gyro_meas - gyro_bias[None, :])
    (rot_total, vel_total, pos_total), _ = jax.lax.scan(integrate_one, start, unbiased)
    return rot_total, vel_total, pos_total
# Test preintegration with true bias.
# Sanity check on the first keyframe interval only: with the exact bias removed,
# the deltas should reflect the true motion up to noise and integration error.
delta_R_test, delta_v_test, delta_p_test = preintegrate_imu(
    accel_measurements[0], gyro_measurements[0], true_accel_bias, true_gyro_bias, imu_dt
)
print("Test preintegration:")
# Print the rotation as its axis-angle (log) vector for readability.
print(f" delta_R: {delta_R_test.log()}")
print(f" delta_v: {delta_v_test}")
print(f" delta_p: {delta_p_test}")
Test preintegration:
delta_R: [ 0.34444118 -0.05179568 -0.05575576]
delta_v: [-3.1116118 -0.30711094 2.8483238 ]
delta_p: [-0.78262115 -0.05092433 0.784495 ]
Cost functions
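We define four types of costs:
IMU preintegration cost: Constrains consecutive pose/velocity pairs based on integrated IMU measurements
Pose prior cost: Sparse GPS-like SE3 priors (position and orientation) at a few keyframes, including the first
Velocity prior cost: Anchors the initial velocity
Bias prior cost: A weak prior pulling the single, constant bias estimate toward zero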
@jaxls.Cost.factory
def imu_cost(
    vals: jaxls.VarValues,
    pose_i: jaxls.SE3Var,
    vel_i: VelocityVar,
    pose_j: jaxls.SE3Var,
    vel_j: VelocityVar,
    bias: BiasVar,
    accel_meas: jax.Array,
    gyro_meas: jax.Array,
    dt_total: float,
    dt_imu: float,
    gravity: jax.Array,
) -> jax.Array:
    """Weighted IMU preintegration residual between keyframes i and j.

    The interval's raw IMU samples are preintegrated with the *current* bias
    estimate. The resulting deltas predict the state at j from the state at i,
    and the prediction is compared with the current estimate at j.

    Returns:
        Residual ``[position (3), velocity (3), rotation (3)]``. Position and
        velocity errors are scaled by 10 and rotation errors by 100, because
        the gyro is treated as the more accurate sensor.
    """
    pose_at_i = vals[pose_i]
    pose_at_j = vals[pose_j]
    velocity_at_i = vals[vel_i]
    bias_estimate = vals[bias]
    rot_i = pose_at_i.rotation()

    d_rot, d_vel, d_pos = preintegrate_imu(
        accel_meas, gyro_meas, bias_estimate[:3], bias_estimate[3:], dt_imu
    )

    # Kinematic prediction of keyframe j from keyframe i:
    #   p_j = p_i + v_i * dt + 0.5 * g * dt^2 + R_i @ delta_p
    #   v_j = v_i + g * dt + R_i @ delta_v
    #   R_j = R_i @ delta_R
    predicted_position = (
        pose_at_i.translation()
        + velocity_at_i * dt_total
        + 0.5 * gravity * dt_total**2
        + rot_i @ d_pos
    )
    predicted_velocity = velocity_at_i + gravity * dt_total + rot_i @ d_vel
    predicted_rotation = rot_i @ d_rot

    position_residual = (pose_at_j.translation() - predicted_position) * 10.0  # m
    velocity_residual = (vals[vel_j] - predicted_velocity) * 10.0  # m/s
    rotation_residual = (
        predicted_rotation.inverse() @ pose_at_j.rotation()
    ).log() * 100.0  # rad
    return jnp.concatenate([position_residual, velocity_residual, rotation_residual])
@jaxls.Cost.factory
def pose_prior_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jaxlie.SE3,
) -> jax.Array:
    """Weighted SE3 prior pulling a pose toward ``target``.

    The error is the tangent-space log of ``pose^-1 @ target``. Its first
    three components (translational) are scaled by 50 and its last three
    (rotational) by 100.
    """
    tangent_error = (vals[var].inverse() @ target).log()
    translational, rotational = tangent_error[:3], tangent_error[3:]
    return jnp.concatenate([translational * 50.0, rotational * 100.0])
@jaxls.Cost.factory
def velocity_prior_cost(
    vals: jaxls.VarValues,
    var: VelocityVar,
    target: jax.Array,
) -> jax.Array:
    """Prior pulling a velocity toward ``target``, with the difference scaled by 50."""
    velocity_error = vals[var] - target
    return velocity_error * 50.0
@jaxls.Cost.factory
def bias_prior_cost(
    vals: jaxls.VarValues,
    var: BiasVar,
    target: jax.Array,
) -> jax.Array:
    """Weak prior on the IMU bias (difference to ``target`` scaled by 0.1).

    The small weight keeps the bias mostly determined by the IMU residuals,
    while still regularizing it.
    """
    bias_error = vals[var] - target
    return bias_error * 0.1
Solving
# Create variables: one pose and one velocity per keyframe, plus a single bias
# variable shared by every IMU interval (the bias is modeled as constant).
pose_vars = jaxls.SE3Var(id=jnp.arange(n_keyframes))
vel_vars = VelocityVar(id=jnp.arange(n_keyframes))
bias_var = BiasVar(id=0)  # Single bias variable (assumed constant)
# Sparse GPS-like priors at every `gps_interval`-th keyframe. These are full SE3
# priors built from ground-truth poses, so they constrain orientation as well as
# position (a real GPS fix would provide position only).
gps_interval = 10  # prior every 10 keyframes (5 seconds at keyframe_dt = 0.5 s)
gps_indices = jnp.arange(0, n_keyframes, gps_interval)
gps_poses = jaxlie.SE3(wxyz_xyz=true_poses.wxyz_xyz[gps_indices])
# IMU cost indices: each cost links consecutive keyframes (i, i + 1).
n_intervals = n_keyframes - 1
imu_i_ids = jnp.arange(n_intervals)
imu_j_ids = jnp.arange(1, n_keyframes)
# Give the gravity vector one row per IMU interval for batched construction.
gravity_batched = jnp.tile(gravity[None, :], (n_intervals, 1))
# Build costs using batched construction.
costs: list[jaxls.Cost] = [
    # Sparse GPS-like pose priors.
    pose_prior_cost(jaxls.SE3Var(id=gps_indices), gps_poses),
    # Anchor start velocity.
    velocity_prior_cost(VelocityVar(id=0), true_velocities[0]),
    # Weak prior on bias (centered at zero).
    bias_prior_cost(bias_var, jnp.zeros(6)),
    # IMU preintegration costs between consecutive keyframes (batched).
    imu_cost(
        jaxls.SE3Var(id=imu_i_ids),
        VelocityVar(id=imu_i_ids),
        jaxls.SE3Var(id=imu_j_ids),
        VelocityVar(id=imu_j_ids),
        bias_var,
        accel_measurements,
        gyro_measurements,
        keyframe_dt,
        imu_dt,
        gravity_batched,
    ),
]
print(f"Created {len(costs)} batched cost objects")
# .tolist() prints plain ints; list() on a JAX array yields a list of 0-d
# Array objects ("[Array(0, dtype=int32), ...]").
print(f"GPS priors at keyframes: {gps_indices.tolist()} (every {gps_interval} keyframes)")
Created 4 batched cost objects
GPS priors at keyframes: [Array(0, dtype=int32), Array(10, dtype=int32), Array(20, dtype=int32), Array(30, dtype=int32)] (every 10 keyframes)
# Initial values: the first pose/velocity come from ground truth; every later
# keyframe is dead-reckoned from the IMU with a zero bias estimate.
@jax.jit
def dead_reckon_trajectory(
    initial_pose: jaxlie.SE3,
    initial_vel: jax.Array,
    accel_meas: jax.Array,
    gyro_meas: jax.Array,
) -> tuple[jaxlie.SE3, jax.Array]:
    """Propagate a trajectory by integrating IMU data with zero bias assumed.

    Keyframe intervals are processed sequentially with ``jax.lax.scan``. Each
    step preintegrates one interval and applies the same kinematic model as
    ``imu_cost``.

    Args:
        initial_pose: Starting pose
        initial_vel: Starting velocity (3,)
        accel_meas: Accelerometer measurements (n_intervals, n_measurements, 3)
        gyro_meas: Gyroscope measurements (n_intervals, n_measurements, 3)

    Returns:
        Tuple of (poses, velocities) for all keyframes, including the initial one.
    """
    no_bias = jnp.zeros(3)

    def propagate(
        state: tuple[jaxlie.SE3, jax.Array],
        interval: tuple[jax.Array, jax.Array],
    ) -> tuple[tuple[jaxlie.SE3, jax.Array], tuple[jax.Array, jax.Array]]:
        pose, velocity = state
        interval_accel, interval_gyro = interval
        d_rot, d_vel, d_pos = preintegrate_imu(
            interval_accel, interval_gyro, no_bias, no_bias, imu_dt
        )
        rot_i = pose.rotation()
        position_next = (
            pose.translation()
            + velocity * keyframe_dt
            + 0.5 * gravity * keyframe_dt**2
            + rot_i @ d_pos
        )
        velocity_next = velocity + gravity * keyframe_dt + rot_i @ d_vel
        pose_next = jaxlie.SE3.from_rotation_and_translation(rot_i @ d_rot, position_next)
        # Emit raw wxyz_xyz so the scan stacks a plain array per keyframe.
        return (pose_next, velocity_next), (pose_next.wxyz_xyz, velocity_next)

    _, (stacked_wxyz_xyz, stacked_velocities) = jax.lax.scan(
        propagate, (initial_pose, initial_vel), (accel_meas, gyro_meas)
    )

    poses = jaxlie.SE3(
        wxyz_xyz=jnp.concatenate([initial_pose.wxyz_xyz[None], stacked_wxyz_xyz])
    )
    velocities = jnp.concatenate([initial_vel[None], stacked_velocities])
    return poses, velocities
# Get initial pose and velocity from ground truth.
initial_pose = jaxlie.SE3(wxyz_xyz=true_poses.wxyz_xyz[0])
initial_vel = true_velocities[0]
# Dead-reckon every later keyframe from the IMU data alone (zero bias assumed).
initial_poses, initial_velocities = dead_reckon_trajectory(
    initial_pose, initial_vel, accel_measurements, gyro_measurements
)
# Compare end positions to show how far uncorrected integration drifts.
print(f"Initial dead-reckoned end position: {initial_poses.translation()[-1]}")
print(f"True end position: {true_poses.translation()[-1]}")
Initial dead-reckoned end position: [-14.247985 10.484206 19.88098 ]
True end position: [-1.545084 -1.4694623 3.2366452]
# Create initial values.
# Initial guess: dead-reckoned poses and velocities, and a zero bias estimate.
initial_vals = jaxls.VarValues.make(
    [
        pose_vars.with_value(initial_poses),
        vel_vars.with_value(initial_velocities),
        bias_var.with_value(jnp.zeros(6)),
    ]
)
# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [pose_vars, vel_vars, bias_var])
# Visualize the problem structure.
problem.show()
# Analyze and solve.
problem = problem.analyze()
solution = problem.solve(
    initial_vals,
    termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)
INFO | Building optimization problem with 45 terms and 81 variables: 45 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 4 costs, 1 variables each: pose_prior_cost
INFO | Vectorizing group with 1 costs, 1 variables each: velocity_prior_cost
INFO | Vectorizing group with 39 costs, 5 variables each: imu_cost
INFO | Vectorizing group with 1 costs, 1 variables each: bias_prior_cost
INFO | step #0: cost=526848.7500 lambd=0.0005 inexact_tol=1.0e-02
INFO | - pose_prior_cost(4): 526848.75000 (avg 21952.03125)
INFO | - velocity_prior_cost(1): 0.00000 (avg 0.00000)
INFO | - imu_cost(39): 0.00000 (avg 0.00000)
INFO | - bias_prior_cost(1): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=3.70e+04 cost_prev=526848.7500 cost_new=350.6612
INFO | step #1: cost=350.6612 lambd=0.0003 inexact_tol=1.0e-02
INFO | - pose_prior_cost(4): 23.79979 (avg 0.99166)
INFO | - velocity_prior_cost(1): 0.06903 (avg 0.02301)
INFO | - imu_cost(39): 326.78799 (avg 0.93102)
INFO | - bias_prior_cost(1): 0.00440 (avg 0.00073)
INFO | accepted=True ATb_norm=3.32e+02 cost_prev=350.6612 cost_new=10.8448
INFO | step #2: cost=10.8448 lambd=0.0001 inexact_tol=7.2e-05
INFO | - pose_prior_cost(4): 0.00997 (avg 0.00042)
INFO | - velocity_prior_cost(1): 0.00021 (avg 0.00007)
INFO | - imu_cost(39): 10.83304 (avg 0.03086)
INFO | - bias_prior_cost(1): 0.00153 (avg 0.00025)
INFO | accepted=True ATb_norm=6.96e+01 cost_prev=10.8448 cost_new=0.0175
INFO | step #3: cost=0.0175 lambd=0.0001 inexact_tol=7.2e-05
INFO | - pose_prior_cost(4): 0.00423 (avg 0.00018)
INFO | - velocity_prior_cost(1): 0.00001 (avg 0.00000)
INFO | - imu_cost(39): 0.01169 (avg 0.00003)
INFO | - bias_prior_cost(1): 0.00156 (avg 0.00026)
INFO | accepted=True ATb_norm=1.77e-01 cost_prev=0.0175 cost_new=0.0174
INFO | step #4: cost=0.0174 lambd=0.0000 inexact_tol=5.8e-06
INFO | - pose_prior_cost(4): 0.00423 (avg 0.00018)
INFO | - velocity_prior_cost(1): 0.00001 (avg 0.00000)
INFO | - imu_cost(39): 0.01163 (avg 0.00003)
INFO | - bias_prior_cost(1): 0.00156 (avg 0.00026)
INFO | accepted=False ATb_norm=1.97e-02 cost_prev=0.0174 cost_new=0.0174
INFO | Terminated @ iteration #5: cost=0.0174 criteria=[0 0 1], term_deltas=6.2e-05,2.9e-03,3.8e-07
# Extract results.
estimated_poses = solution[pose_vars]
estimated_velocities = solution[vel_vars]
estimated_bias = solution[bias_var]
# Compare the recovered bias with the value used to synthesize the IMU data.
# The bias vector is laid out as [accel (3), gyro (3)].
print("Bias estimation:")
print(f" True: accel={true_accel_bias}, gyro={true_gyro_bias}")
print(f" Estimated: accel={estimated_bias[:3]}, gyro={estimated_bias[3:]}")
print(
    f" Error: accel={jnp.abs(estimated_bias[:3] - true_accel_bias)}, "
    f"gyro={jnp.abs(estimated_bias[3:] - true_gyro_bias)}"
)
Bias estimation:
True: accel=[ 0.3 -0.2 0.15], gyro=[ 0.02 -0.01 0.005]
Estimated: accel=[ 0.31533828 -0.18791492 0.14483762], gyro=[ 0.02158698 -0.00830974 0.00490012]
Error: accel=[0.01533827 0.01208508 0.00516239], gyro=[1.5869755e-03 1.6902629e-03 9.9882483e-05]
Visualization
Compare the estimated trajectory with ground truth and the dead-reckoned initialization:
# Extract positions.
# NumPy copies of the true, dead-reckoned (initial), and optimized positions;
# presumably consumed by plotting cells not included here.
true_positions = np.array(true_poses.translation())
initial_positions = np.array(initial_poses.translation())
estimated_positions = np.array(estimated_poses.translation())
# Compute errors.
# Per-keyframe Euclidean distance between optimized and true positions.
position_errors = jnp.linalg.norm(
    estimated_poses.translation() - true_poses.translation(), axis=-1
)
print(
    f"Position errors: mean={float(jnp.mean(position_errors)):.4f}m, "
    f"max={float(jnp.max(position_errors)):.4f}m"
)
Position errors: mean=0.0406m, max=0.1023m
Key observations
The optimizer jointly estimates poses, velocities, and IMU biases by minimizing residuals between predicted and measured relative motion:
Bias calibration: Accelerometer bias becomes observable because gravity projects differently onto accelerometer axes at varied orientations. With roll and pitch changes throughout the trajectory, the optimizer can separate true acceleration from constant bias.
Trajectory correction: Dead reckoning (red dashed) drifts due to the biased IMU measurements. The optimizer corrects this drift to match the ground truth (green).
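Preintegration efficiency: Preintegration collapses all IMU samples of a keyframe interval into a single relative-motion constraint that does not depend on the absolute pose or velocity at the start of the interval. It does depend on the bias. In this example, `imu_cost` therefore re-runs preintegration from the raw measurements with the current bias estimate at every cost evaluation. Production systems typically preintegrate once and apply a first-order bias correction instead.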