Schur complement elimination#

Many least squares problems have a Hessian with a dominant variable type whose own block is block-diagonal: variables of that type never couple to each other, only to the rest of the problem. When that block is large, it is much cheaper to eliminate it analytically and solve the smaller, better-conditioned system that remains. jaxls does this automatically via Schur-complement elimination.

This page discusses:

  • The Hessian structure that makes elimination possible, and the reduced system it produces

  • How jaxls decides what to eliminate (analyze(schur_elimination=...))

  • How each linear solver uses the reduced system

  • Verifying that the reduced solve is exact

The reduced system#

Partition the variables into a set we keep (\(c\)) and a set we eliminate (\(l\)). A damped Gauss-Newton (Levenberg-Marquardt) step solves the normal equations, written in block form as

\[\begin{split} \begin{bmatrix} H_{cc} & W \\ W^\top & V \end{bmatrix} \begin{bmatrix} \Delta c \\ \Delta l \end{bmatrix} = \begin{bmatrix} b_c \\ b_l \end{bmatrix}, \end{split}\]

where \(H = J^\top J + \lambda I\) is the damped Hessian and \(b = -J^\top r\) is the negative gradient. The block \(V\) is the Hessian of the eliminated variables among themselves. The structure jaxls exploits is this: when no single cost couples two eliminated variables, \(V\) is block-diagonal, with one small block per eliminated variable. Block-diagonal matrices are cheap to invert because each block can be inverted independently.

That lets us solve the second block row for \(\Delta l\) and substitute it into the first, leaving a smaller reduced (or Schur-complement) system in \(\Delta c\) alone:

\[ S\,\Delta c = b_c - W V^{-1} b_l, \qquad S = H_{cc} - W V^{-1} W^\top. \]

Once \(\Delta c\) is known, the eliminated update is recovered by back-substitution:

\[ \Delta l = V^{-1}\,(b_l - W^\top \Delta c). \]
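For readers who want the intermediate step, both formulas follow directly from the two block rows of the system. This is a short derivation using only the symbols already defined above; it assumes \(V\) is invertible, which holds for the damped Hessian with \(\lambda > 0\):

\[\begin{aligned} W^\top \Delta c + V\,\Delta l = b_l \;&\Rightarrow\; \Delta l = V^{-1}\,(b_l - W^\top \Delta c), \\ H_{cc}\,\Delta c + W\,\Delta l = b_c \;&\Rightarrow\; H_{cc}\,\Delta c + W V^{-1}\,(b_l - W^\top \Delta c) = b_c \\ &\Rightarrow\; (H_{cc} - W V^{-1} W^\top)\,\Delta c = b_c - W V^{-1} b_l. \end{aligned}\]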

This helps in two ways. The reduced matrix \(S\) is much smaller than the full Hessian, and it is better conditioned, so the linear solver does much less work. Inverting \(V\), the one step that could be costly, is inexpensive because \(V\) is block-diagonal.
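The algebra above can be checked numerically. The following is a minimal NumPy sketch, not jaxls code: the problem sizes, the random Jacobian, and the helper `apply_V_inv` are assumptions made purely for illustration. It builds a small damped Hessian whose eliminated block is block-diagonal, solves the reduced system, back-substitutes, and compares the result with a direct solve of the full system.

```python
import numpy as np

rng = np.random.default_rng(0)
n_keep, n_blocks, block, rows_per_var = 6, 20, 3, 4  # illustrative sizes
n_elim = n_blocks * block

# Each group of residual rows touches the kept variables and exactly ONE
# eliminated variable, so no cost couples two eliminated variables.
J = np.zeros((n_blocks * rows_per_var, n_keep + n_elim))
for k in range(n_blocks):
    rows = slice(k * rows_per_var, (k + 1) * rows_per_var)
    J[rows, :n_keep] = rng.normal(size=(rows_per_var, n_keep))
    c0 = n_keep + k * block
    J[rows, c0:c0 + block] = rng.normal(size=(rows_per_var, block))

residual = rng.normal(size=J.shape[0])
H = J.T @ J + 1e-2 * np.eye(n_keep + n_elim)  # damped Hessian
b = -J.T @ residual

H_cc, W, V = H[:n_keep, :n_keep], H[:n_keep, n_keep:], H[n_keep:, n_keep:]
b_c, b_l = b[:n_keep], b[n_keep:]

# Invert V one small diagonal block at a time (batched 3x3 inverses).
V_blocks = np.stack(
    [V[k * block:(k + 1) * block, k * block:(k + 1) * block] for k in range(n_blocks)]
)
V_inv_blocks = np.linalg.inv(V_blocks)


def apply_V_inv(x):
    """Apply V^{-1} to a vector (n_elim,) or a matrix (n_elim, m)."""
    out = V_inv_blocks @ x.reshape(n_blocks, block, -1)
    return out.reshape(x.shape)


# Reduced system, then back-substitution.
S = H_cc - W @ apply_V_inv(W.T)
delta_c = np.linalg.solve(S, b_c - W @ apply_V_inv(b_l))
delta_l = apply_V_inv(b_l - W.T @ delta_c)

delta_schur = np.concatenate([delta_c, delta_l])
delta_full = np.linalg.solve(H, b)
print("reduced system shape:", S.shape)
print("max abs difference vs. full solve:", np.abs(delta_schur - delta_full).max())
```

Only the small matrix \(S\) is ever factored here; the work on \(V\) is a batch of independent small inverses, which is the part that parallelizes well.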

Bundle adjustment as an example#

Bundle adjustment is the canonical case. It jointly optimizes camera poses and 3D landmark positions; there are typically far more landmarks than cameras, and every landmark couples only to the cameras that observe it, never to another landmark. So the landmark block of the Hessian is block-diagonal (one block per landmark) and dominates the problem by size. Eliminating the landmarks leaves the small camera-only system \(S\), which is what the linear solver actually factors or iterates on. The same structure appears whenever a problem has a large, mutually-uncoupled variable type, so the technique is not specific to vision.
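The sparsity claim can be illustrated without any real camera model. The sketch below is an assumption-laden toy, not jaxls code: it fills a Jacobian with random values in the pattern a bundle adjustment problem would have (each observation touches one camera and one landmark) and checks that the landmark block of \(J^\top J\) has no entries outside its diagonal blocks.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cameras, n_points = 8, 60
cam_dim, point_dim, res_dim = 6, 3, 2  # tangent dims and residual dim per observation

# Each point is observed by three distinct cameras.
observations = [
    (int(i), j)
    for j in range(n_points)
    for i in rng.choice(n_cameras, size=3, replace=False)
]

n_cam_cols = n_cameras * cam_dim
n_cols = n_cam_cols + n_points * point_dim
J = np.zeros((len(observations) * res_dim, n_cols))
for k, (i, j) in enumerate(observations):
    rows = slice(k * res_dim, (k + 1) * res_dim)
    # Random placeholder derivatives; only the sparsity pattern matters here.
    J[rows, i * cam_dim:(i + 1) * cam_dim] = rng.normal(size=(res_dim, cam_dim))
    p0 = n_cam_cols + j * point_dim
    J[rows, p0:p0 + point_dim] = rng.normal(size=(res_dim, point_dim))

H = J.T @ J
V = H[n_cam_cols:, n_cam_cols:]  # landmark-landmark block

# Boolean mask that is True inside the per-landmark diagonal blocks.
block_id = np.repeat(np.arange(n_points), point_dim)
in_diag_block = block_id[:, None] == block_id[None, :]
off_block_max = np.abs(V[~in_diag_block]).max()

print("total tangent dim:", n_cols)
print("landmark block shape:", V.shape)
print("largest entry outside the diagonal blocks:", off_block_max)
```

Because no observation involves two landmarks, no residual row has nonzeros in the columns of two different landmarks, so every cross-landmark entry of \(J^\top J\) vanishes.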

Automatic elimination#

Elimination is on by default. analyze() inspects the problem’s sparsity, and if a dominant block-diagonal variable type exists it builds an elimination plan; then solve() runs on the reduced system and back-substitutes. Nothing else in your code changes.

To make this concrete, we build a small synthetic bundle adjustment problem: a handful of cameras, many points, each point seen by a few cameras.

import jax
import jax.numpy as jnp
import numpy as np

import jaxls

jax.config.update("jax_enable_x64", True)


class CameraVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(6)):
    """Toy 6-DoF camera (3 rotation, 3 translation)."""


class PointVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(3)):
    """3D landmark position."""


# A few cameras, many points; each point is observed by three cameras.
rng = np.random.default_rng(0)
n_cameras = 8
n_points = 60

cam_indices = []
point_indices = []
for j in range(n_points):
    for i in rng.choice(n_cameras, size=3, replace=False):
        cam_indices.append(int(i))
        point_indices.append(j)
cam_indices = jnp.array(cam_indices)
point_indices = jnp.array(point_indices)
observations = jnp.array(rng.normal(0.0, 0.1, (cam_indices.shape[0], 2)))

We build a single batched cost over all observations, plus the variables:

@jaxls.Cost.factory
def reprojection_cost(vals, cam: CameraVar, point: PointVar, observed):
    p = vals[point] + vals[cam][:3]
    predicted = p[:2] / (p[2] + 5.0)
    return predicted - observed


costs = [
    reprojection_cost(
        CameraVar(cam_indices),
        PointVar(point_indices),
        observations,
    )
]
variables = [CameraVar(jnp.arange(n_cameras)), PointVar(jnp.arange(n_points))]

problem = jaxls.LeastSquaresProblem(costs, variables).analyze()
2026-06-19 03:21:43.252 | INFO     | jaxls._problem:analyze:231 - Building optimization problem with 180 terms and 68 variables: 180 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
2026-06-19 03:21:43.739 | INFO     | jaxls._problem:analyze:358 - Vectorizing group with 180 costs, 2 variables each: reprojection_cost
2026-06-19 03:21:44.405 | INFO     | jaxls._problem:analyze:450 - Variable elimination: eliminating PointVar (180 of 228 tangent dims); reduced system is 48-dimensional

The log line from analyze() (above) reports what was eliminated and how much the system shrank. We can also inspect it directly: the full tangent dimension versus the reduced (kept) dimension.

plan = problem._elimination
assert plan is not None  # An elimination plan was built.

print(f"full tangent dim:    {problem._tangent_dim}")
print(f"reduced (kept) dim:  {plan.reduced_dim}")
print(f"eliminated dim:      {problem._tangent_dim - plan.reduced_dim}")
full tangent dim:    228
reduced (kept) dim:  48
eliminated dim:      180

The PointVar block (3 dims × 60 points = 180) is eliminated, leaving only the camera system (6 dims × 8 cameras = 48) for the linear solver to handle directly.

How the choice is made#

The automatic rule (schur_elimination="auto", the default):

  • A variable type is eligible to be eliminated only if no single cost couples two variables of that type, which is what makes its Hessian block block-diagonal. In a pose graph, where a cost couples two poses, the pose type is not eligible, so no elimination happens and the full system is solved.

  • jaxls eliminates a set of eligible types, not just one: it adds them greedily by total tangent size (preferring many small variables), as long as the combined block stays block-diagonal, and always keeps at least one type for the reduced system. In bundle adjustment with both 3D points and per-point colors, for instance, both can be eliminated together.

  • Elimination is only used if the eliminated block covers at least 5% of the tangent dimension. Below that, the per-iteration bookkeeping starts to outweigh the savings from the smaller system, so the full solve is used instead.
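The eligibility test and the size threshold in the list above can be sketched in a few lines. This is a hypothetical illustration of the stated rules, not the jaxls implementation: the functions `eligible_types` and `worth_eliminating`, and the representation of a cost as a tuple of variable type names, are invented here. The greedy selection among eligible types is not modeled.

```python
from collections import Counter


def eligible_types(costs):
    """Return the variable types that no single cost uses more than once.

    `costs` is a list of tuples; each tuple names the variable types that one
    cost touches (hypothetical representation, for illustration only).
    """
    all_types = {t for cost in costs for t in cost}
    self_coupled = {
        t for cost in costs for t, count in Counter(cost).items() if count > 1
    }
    return all_types - self_coupled


def worth_eliminating(eliminated_dim, total_dim, min_fraction=0.05):
    """Apply the 5% size threshold described above."""
    return eliminated_dim >= min_fraction * total_dim


# Bundle adjustment: every cost touches one camera and one point.
ba_costs = [("CameraVar", "PointVar")] * 180
# Pose graph: every cost couples two poses, so the pose type is not eligible.
pose_graph_costs = [("PoseVar", "PoseVar")] * 50

print(sorted(eligible_types(ba_costs)))
print(eligible_types(pose_graph_costs))
print(worth_eliminating(eliminated_dim=180, total_dim=228))
```

In the bundle adjustment case both types pass the eligibility test on their own; the rule that at least one type must be kept, together with the preference by total tangent size, is what leaves the cameras in the reduced system.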

Controlling elimination#

analyze(schur_elimination=...) takes three forms:

  • "auto" (default): infer a dominant block-diagonal type as above.

  • "off": skip elimination and solve the full system. Useful for debugging or benchmarking against the non-eliminated solve.

  • a tuple of variable types, e.g. (PointVar,) or (PointVar, ColorVar): eliminate exactly those types. Each must be block-diagonal, or analyze() raises a ValueError.

Multiple types can be eliminated at once (as the tuple form and the automatic rule above both allow); they are removed together in one Schur step. What is not supported is nested elimination: eliminating a further type from the already-reduced system in a second step.

# Explicit: eliminate exactly the point variables (same as "auto" picks here).
problem_explicit = jaxls.LeastSquaresProblem(costs, variables).analyze(
    schur_elimination=(PointVar,)
)
print("explicit plan present:", problem_explicit._elimination is not None)

# Off: solve the full system.
problem_off = jaxls.LeastSquaresProblem(costs, variables).analyze(
    schur_elimination="off"
)
print("plan when off:        ", problem_off._elimination)
explicit plan present: True
plan when off:         None
2026-06-19 03:21:44.418 | INFO     | jaxls._problem:analyze:231 - Building optimization problem with 180 terms and 68 variables: 180 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
2026-06-19 03:21:44.425 | INFO     | jaxls._problem:analyze:358 - Vectorizing group with 180 costs, 2 variables each: reprojection_cost
2026-06-19 03:21:44.440 | INFO     | jaxls._problem:analyze:450 - Variable elimination: eliminating PointVar (180 of 228 tangent dims); reduced system is 48-dimensional
2026-06-19 03:21:44.441 | INFO     | jaxls._problem:analyze:231 - Building optimization problem with 180 terms and 68 variables: 180 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
2026-06-19 03:21:44.445 | INFO     | jaxls._problem:analyze:358 - Vectorizing group with 180 costs, 2 variables each: reprojection_cost

Each solver uses the reduced system#

All three linear solvers consume the same elimination plan; they differ only in how they handle the reduced matrix \(S\):

| linear_solver | Reduced system \(S\,\Delta c = \tilde b\), with \(\tilde b = b_c - W V^{-1} b_l\), is solved by |
|---|---|
| dense_cholesky | Forming \(S\) densely and Cholesky-factoring it. |
| conjugate_gradient (default) | Matrix-free CG: \(S\) is never formed; each product applies \(H_{cc}\), \(V^{-1}\), and \(W\) in turn. |
| cholmod | Assembling \(S\) as a sparse matrix and factoring it with CHOLMOD, the Ceres/g2o “Schur + sparse-direct” combination. |

In every case the block-diagonal elimination of \(V\) runs on-device, and only the small, well-conditioned reduced system reaches the linear solver.
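To make the matrix-free option concrete, here is a minimal NumPy sketch of conjugate gradient applied to \(S\) without forming it. It is not the jaxls solver: the synthetic problem, the helper names (`apply_V_inv`, `apply_S`, `conjugate_gradient`), and the absence of a preconditioner are all simplifying assumptions. The dense \(S\) is built only at the end, as a reference to compare against.

```python
import numpy as np

rng = np.random.default_rng(1)
n_keep, n_blocks, block, rows_per_var = 6, 20, 3, 4  # illustrative sizes
n_elim = n_blocks * block

# Synthetic Jacobian: each row group touches the kept variables and one
# eliminated variable, so the eliminated block of the Hessian is block-diagonal.
J = np.zeros((n_blocks * rows_per_var, n_keep + n_elim))
for k in range(n_blocks):
    rows = slice(k * rows_per_var, (k + 1) * rows_per_var)
    J[rows, :n_keep] = rng.normal(size=(rows_per_var, n_keep))
    c0 = n_keep + k * block
    J[rows, c0:c0 + block] = rng.normal(size=(rows_per_var, block))

H = J.T @ J + 1e-2 * np.eye(n_keep + n_elim)
b = -J.T @ rng.normal(size=J.shape[0])
H_cc, W, V = H[:n_keep, :n_keep], H[:n_keep, n_keep:], H[n_keep:, n_keep:]
b_c, b_l = b[:n_keep], b[n_keep:]

V_inv_blocks = np.linalg.inv(
    np.stack([V[k * block:(k + 1) * block, k * block:(k + 1) * block]
              for k in range(n_blocks)])
)


def apply_V_inv(x):
    """Apply the block-diagonal inverse to a vector of length n_elim."""
    return (V_inv_blocks @ x.reshape(n_blocks, block, 1)).reshape(-1)


def apply_S(v):
    """Compute S @ v = H_cc v - W V^{-1} W^T v without forming S."""
    return H_cc @ v - W @ apply_V_inv(W.T @ v)


def conjugate_gradient(matvec, rhs, tol=1e-12, max_iters=200):
    """Plain (unpreconditioned) CG for a symmetric positive-definite operator."""
    x = np.zeros_like(rhs)
    r = rhs - matvec(x)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iters):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol * np.linalg.norm(rhs):
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


rhs = b_c - W @ apply_V_inv(b_l)
delta_c_cg = conjugate_gradient(apply_S, rhs)

# Reference: form S densely and solve directly.
S_dense = H_cc - W @ np.linalg.solve(V, W.T)
delta_c_direct = np.linalg.solve(S_dense, rhs)
print("max abs difference, CG vs. dense:", np.abs(delta_c_cg - delta_c_direct).max())
```

The trade-off is the usual one: the matrix-free product avoids the memory and fill-in of an explicit \(S\), at the cost of one pass through \(W\), \(V^{-1}\), and \(W^\top\) per CG iteration.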

The reduced solve is exact#

Eliminating \(l\) is exact algebra, not an approximation: the reduced step produces the identical update to solving the full damped system directly. We can check this by comparing a Schur solve against the same problem solved with elimination turned off.

init = jaxls.VarValues.make(
    [
        CameraVar(jnp.arange(n_cameras)).with_value(
            jnp.array(rng.normal(0.0, 0.05, (n_cameras, 6)))
        ),
        PointVar(jnp.arange(n_points)).with_value(
            jnp.array(rng.normal(0.0, 0.05, (n_points, 3)))
        ),
    ]
)

solve_kwargs = dict(
    linear_solver="dense_cholesky",
    termination=jaxls.TerminationConfig(max_iterations=10, early_termination=False),
    verbose=False,
    return_summary=True,
)

_, summary_schur = problem.solve(init, **solve_kwargs)
_, summary_full = problem_off.solve(init, **solve_kwargs)

c_schur = np.asarray(summary_schur.cost_history[:10])
c_full = np.asarray(summary_full.cost_history[:10])
rel = np.abs(c_schur - c_full) / np.abs(c_full)
print(f"max relative cost-trajectory difference: {rel.max():.2e}")
max relative cost-trajectory difference: 6.81e-13

The two trajectories match to roundoff: the reduced solve traces the same Levenberg-Marquardt path as the full solve, up to floating-point error, at a fraction of the linear-algebra cost.

See also#

  • Bundle adjustment: a full bundle adjustment example where elimination is active automatically.

  • Sparse matrices: how jaxls represents and multiplies the sparse Jacobian the elimination is built on.

  • Tips and gotchas: choosing a linear solver.