Sparse matrices#

jaxls solves nonlinear least squares problems with Levenberg-Marquardt, a damped variant of the Gauss-Newton method.

This page discusses:

  • The Gauss-Newton approximation and why Jacobian structure matters

  • jaxls’s BlockRowSparseMatrix representation, which is designed to be GPU-friendly

  • Linear solvers and preconditioning strategies

The Gauss-Newton approximation#

Newton’s method for minimizing \(f(x) = \frac{1}{2}\sum_i \|r_i(x)\|^2\) requires the Hessian \(\nabla^2 f\), which is expensive to compute. Gauss-Newton instead approximates the Hessian using only first-order information:

\[\nabla^2 f \approx J^\top J\]

where \(J = \partial r / \partial x\) is the Jacobian of the stacked residual vector \(r(x) = [r_0(x)^\top, r_1(x)^\top, \ldots]^\top\) with respect to the optimization variables.
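
For intuition, here is a minimal sketch of a single step on a toy problem, forming \(J\) densely with jax.jacfwd (the residual is hypothetical, and jaxls never materializes the full Jacobian like this):

import jax
import jax.numpy as jnp

def residuals(x):
    # Toy stacked residual r(x); any smooth function works here.
    return jnp.array([x[0] - 1.0, x[1] - 2.0, x[0] * x[1] - 3.0])

x = jnp.zeros(2)
r = residuals(x)
J = jax.jacfwd(residuals)(x)  # Jacobian of the stacked residual, shape (3, 2).
# Gauss-Newton step: solve (J^T J) dx = -J^T r (tiny damping added for safety).
dx = jnp.linalg.solve(J.T @ J + 1e-8 * jnp.eye(2), -J.T @ r)
x_new = x + dx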

Why Jacobians are sparse#

Residual terms in nonlinear least squares often depend on only a subset of the optimization variables. For example, a pairwise cost between two poses only depends on those two poses, not the hundreds of other poses in a SLAM problem. This produces sparse Jacobians: row block \(i\) of \(J\) has non-zeros only in columns corresponding to the variables that \(r_i\) depends on.
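
This is easy to verify directly: differentiating a residual that touches only two variables yields zero blocks for every other variable. A minimal sketch (the residual and variable layout are hypothetical, not jaxls API):

import jax
import jax.numpy as jnp

def pairwise_residual(poses):
    # Depends only on poses 0 and 1; poses 2..4 never appear.
    return poses[0] - poses[1]

poses = jnp.zeros((5, 2))  # Five 2D "poses".
J = jax.jacfwd(pairwise_residual)(poses)  # Shape (2, 5, 2).
assert bool(jnp.all(J[:, 2:, :] == 0.0))  # Columns for poses 2-4 are all zero.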

This sparsity pattern corresponds to the adjacency matrix of a bipartite factor graph:

  • Variables (circles): The parameters we’re optimizing

  • Costs (squares): Residual functions that connect to their dependent variables

Consider an optimization problem with variables \(x_0\), \(x_1\), \(x_2\) and three types of costs: A (unary), B (binary between adjacent variables), and C (binary between non-adjacent variables).


# Visualize the factor graph structure and its Jacobian sparsity pattern.

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML
import numpy as np

# Create a simple problem structure visualization.
# Variables: x0 (dim 2), x1 (dim 2), x2 (dim 2)
# Cost types: A (unary), B (binary adjacent), C (binary non-adjacent)

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Problem Structure", "Jacobian Sparsity Pattern"),
    column_widths=[0.45, 0.55],
    horizontal_spacing=0.12,
)

# Problem structure layout.
var_x = [0, 2, 4]
var_y = [1, 1, 1]
var_labels = ["x₀", "x₁", "x₂"]
var_dims = [2, 2, 2]  # Tangent dimensions (all 2D).

# C₀ placed above the variables for cleaner layout.
cost_x = [0, 1, 3, 4, 2]
cost_y = [0, 0, 0, 0, 2]
cost_labels = ["A₀", "B₀", "B₁", "A₁", "C₀"]
cost_dims = [2, 2, 2, 2, 2]  # Residual dimensions (all 2).
# Colors: A costs (blue), B costs (orange), C costs (green).
cost_colors = ["#1976d2", "#f57c00", "#f57c00", "#1976d2", "#388e3c"]

# Connections: (cost_idx, var_idx).
# A₀→x₀, B₀→x₀,x₁, B₁→x₁,x₂, A₁→x₂, C₀→x₀,x₂
edges = [(0, 0), (1, 0), (1, 1), (2, 1), (2, 2), (3, 2), (4, 0), (4, 2)]

# Draw edges.
for ci, vi in edges:
    fig.add_trace(
        go.Scatter(
            x=[cost_x[ci], var_x[vi]],
            y=[cost_y[ci], var_y[vi]],
            mode="lines",
            line=dict(color="#888", width=2),
            showlegend=False,
            hoverinfo="skip",
        ),
        row=1,
        col=1,
    )

# Draw variables (circles).
fig.add_trace(
    go.Scatter(
        x=var_x,
        y=var_y,
        mode="markers+text",
        marker=dict(size=40, color="#2196F3", line=dict(width=2, color="white")),
        text=var_labels,
        textposition="middle center",
        textfont=dict(color="white", size=14),
        name="Variables",
        hovertemplate="%{text}<br>dim=%{customdata}<extra></extra>",
        customdata=var_dims,
    ),
    row=1,
    col=1,
)

# Draw costs (squares) - color by cost type.
for i, (cx, cy, label, color, dim) in enumerate(
    zip(cost_x, cost_y, cost_labels, cost_colors, cost_dims)
):
    fig.add_trace(
        go.Scatter(
            x=[cx],
            y=[cy],
            mode="markers+text",
            marker=dict(
                size=35, color=color, symbol="square", line=dict(width=2, color="white")
            ),
            text=[label],
            textposition="middle center",
            textfont=dict(color="white", size=12),
            name=label,
            hovertemplate=f"{label}<br>residual dim={dim}<extra></extra>",
            showlegend=False,
        ),
        row=1,
        col=1,
    )

# Build Jacobian sparsity pattern.
# Order costs by type for the matrix: A₀, A₁, B₀, B₁, C₀
cost_order = [0, 3, 1, 2, 4]  # Reorder to group by type.
ordered_cost_dims = [cost_dims[i] for i in cost_order]
ordered_cost_labels = [cost_labels[i] for i in cost_order]
ordered_cost_colors_idx = [0, 0, 1, 1, 2]  # A, A, B, B, C

n_rows = sum(ordered_cost_dims)  # 10
n_cols = sum(var_dims)  # 6

row_starts = [0] + list(np.cumsum(ordered_cost_dims)[:-1])
col_starts = [0] + list(np.cumsum(var_dims)[:-1])

# Build edge map for reordered costs.
# Original edges: A₀→x₀, B₀→x₀,x₁, B₁→x₁,x₂, A₁→x₂, C₀→x₀,x₂
# Reordered: A₀→x₀, A₁→x₂, B₀→x₀,x₁, B₁→x₁,x₂, C₀→x₀,x₂
reordered_edges = [
    (0, [0]),  # A₀ → x₀
    (1, [2]),  # A₁ → x₂
    (2, [0, 1]),  # B₀ → x₀, x₁
    (3, [1, 2]),  # B₁ → x₁, x₂
    (4, [0, 2]),  # C₀ → x₀, x₂ (non-adjacent!)
]

# Create color matrix for heatmap.
color_matrix = np.zeros((n_rows, n_cols))
color_vals = [1, 1, 2, 2, 3]  # A=1 (blue), B=2 (orange), C=3 (green)
for row_idx, (_, var_indices) in enumerate(reordered_edges):
    r0, r1 = row_starts[row_idx], row_starts[row_idx] + ordered_cost_dims[row_idx]
    for vi in var_indices:
        c0, c1 = col_starts[vi], col_starts[vi] + var_dims[vi]
        color_matrix[r0:r1, c0:c1] = color_vals[row_idx]

# Custom colorscale: white, blue, orange, green.
fig.add_trace(
    go.Heatmap(
        z=color_matrix[::-1],
        colorscale=[[0, "white"], [0.33, "#1976d2"], [0.67, "#f57c00"], [1, "#388e3c"]],
        showscale=False,
        zmin=0,
        zmax=3,
        hovertemplate="row %{y}, col %{x}<extra></extra>",
    ),
    row=1,
    col=2,
)

# Add grid lines for Jacobian.
for i in range(n_rows + 1):
    fig.add_shape(
        type="line",
        x0=-0.5,
        x1=n_cols - 0.5,
        y0=i - 0.5,
        y1=i - 0.5,
        line=dict(color="#ddd", width=1),
        row=1,
        col=2,
    )
for j in range(n_cols + 1):
    fig.add_shape(
        type="line",
        x0=j - 0.5,
        x1=j - 0.5,
        y0=-0.5,
        y1=n_rows - 0.5,
        line=dict(color="#ddd", width=1),
        row=1,
        col=2,
    )

# Add block separators (between cost types).
type_boundaries = [4, 8]  # After A costs, after B costs
for rs in type_boundaries:
    fig.add_shape(
        type="line",
        x0=-0.5,
        x1=n_cols - 0.5,
        y0=n_rows - rs - 0.5,
        y1=n_rows - rs - 0.5,
        line=dict(color="#999", width=2),
        row=1,
        col=2,
    )
for cs in col_starts[1:]:
    fig.add_shape(
        type="line",
        x0=cs - 0.5,
        x1=cs - 0.5,
        y0=-0.5,
        y1=n_rows - 0.5,
        line=dict(color="#2196F3", width=2),
        row=1,
        col=2,
    )

# Add column labels (variables) at bottom.
col_centers = [col_starts[i] + var_dims[i] / 2 - 0.5 for i in range(len(var_dims))]
for i, (cx, label) in enumerate(zip(col_centers, var_labels)):
    fig.add_annotation(
        x=cx,
        y=-1.0,
        text=label,
        showarrow=False,
        font=dict(size=14),
        xref="x2",
        yref="y2",
    )

# Add row labels (costs) at left - show type groups.
type_labels = ["A", "B", "C"]
type_row_ranges = [(0, 4), (4, 8), (8, 10)]
for label, (r0, r1) in zip(type_labels, type_row_ranges):
    ry = n_rows - (r0 + r1) / 2 - 0.5
    fig.add_annotation(
        x=-0.8,
        y=ry,
        text=label,
        showarrow=False,
        font=dict(size=12),
        xref="x2",
        yref="y2",
    )

fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False, row=1, col=1)
fig.update_yaxes(showgrid=False, zeroline=False, showticklabels=False, row=1, col=1)
fig.update_xaxes(showgrid=False, zeroline=False, showticklabels=False, row=1, col=2)
fig.update_yaxes(showgrid=False, zeroline=False, showticklabels=False, row=1, col=2)

fig.update_layout(
    height=350,
    margin=dict(t=40, b=60, l=60, r=40),
    showlegend=False,
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

In this example, the Jacobian is ~47% zeros (28 of its 60 entries), and this fraction grows with problem size. For large SLAM or bundle adjustment problems, Jacobians can be >99% sparse. Exploiting this sparsity is critical for scaling to large problems.

Sparse matrix formats#

jaxls implements three sparse matrix representations. You can select between them using the sparse_mode parameter:

# Default: block-row format (best for CG on GPU).
solution = problem.solve(sparse_mode="blockrow")

# COO format (converts to JAX BCOO).
solution = problem.solve(sparse_mode="coo")

# CSR format (always used for CHOLMOD).
solution = problem.solve(sparse_mode="csr")

BlockRowSparseMatrix#

jaxls’s default and recommended sparse matrix representation uses what we call sparse “block-rows”. Costs of the same type have Jacobians with identical structure: same block sizes, just different values and column positions. jaxls exploits this by batching block-rows from the same cost type.

Using the same example from above:

\[\begin{split} J = \left[\begin{array}{cc|cc|cc} \color{#1976d2}{0.8} & \color{#1976d2}{0.3} & 0 & 0 & 0 & 0 \\ \color{#1976d2}{0.1} & \color{#1976d2}{0.9} & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & \color{#1976d2}{0.6} & \color{#1976d2}{0.2} \\ 0 & 0 & 0 & 0 & \color{#1976d2}{0.4} & \color{#1976d2}{0.7} \\ \hline \color{#f57c00}{1.2} & \color{#f57c00}{0.4} & \color{#f57c00}{0.7} & \color{#f57c00}{0.2} & 0 & 0 \\ \color{#f57c00}{0.5} & \color{#f57c00}{0.6} & \color{#f57c00}{0.3} & \color{#f57c00}{0.8} & 0 & 0 \\ 0 & 0 & \color{#f57c00}{0.9} & \color{#f57c00}{0.1} & \color{#f57c00}{0.4} & \color{#f57c00}{0.7} \\ 0 & 0 & \color{#f57c00}{0.2} & \color{#f57c00}{0.5} & \color{#f57c00}{0.8} & \color{#f57c00}{0.3} \\ \hline \color{#388e3c}{0.3} & \color{#388e3c}{0.9} & 0 & 0 & \color{#388e3c}{0.5} & \color{#388e3c}{0.1} \\ \color{#388e3c}{0.7} & \color{#388e3c}{0.4} & 0 & 0 & \color{#388e3c}{0.2} & \color{#388e3c}{0.8} \end{array}\right] \begin{array}{l} \leftarrow A_0 \\ \\ \leftarrow A_1 \\ \\ \leftarrow B_0 \\ \\ \leftarrow B_1 \\ \\ \leftarrow C_0 \\ \end{array} \end{split}\]

Each cost type becomes one SparseBlockRow:

@jdc.pytree_dataclass
class BlockRowSparseMatrix:
    block_rows: tuple[SparseBlockRow, ...]  # One per cost type.

@jdc.pytree_dataclass
class SparseBlockRow:
    num_cols: jdc.Static[int]                      # Total matrix width.
    block_num_cols: jdc.Static[tuple[int, ...]]    # Block widths.
    start_cols: tuple[jax.Array, ...]              # Column positions (traced).
    blocks_concat: jax.Array                       # Jacobian values (traced).

| Cost type | blocks_concat shape | block_num_cols | start_cols |
| --- | --- | --- | --- |
| A | (2, 2, 2) | (2,) | [[0], [4]] |
| B | (2, 2, 4) | (2, 2) | [[0, 2], [2, 4]] |
| C | (1, 2, 4) | (2, 2) | [[0, 4]] |

The zeros are never stored. start_cols records where each block lands in the full matrix.
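
As an illustration, the cost-type-B row of the table could be written out by hand like this (values taken from the matrix above; the exact containers jaxls builds internally are not shown here):

import jax.numpy as jnp

# Two B costs, each with a 2-row residual and two 2-column Jacobian blocks.
blocks_concat = jnp.array([
    [[1.2, 0.4, 0.7, 0.2],   # B₀: blocks for x₀ and x₁, concatenated.
     [0.5, 0.6, 0.3, 0.8]],
    [[0.9, 0.1, 0.4, 0.7],   # B₁: blocks for x₁ and x₂, concatenated.
     [0.2, 0.5, 0.8, 0.3]],
])  # Shape (2, 2, 4): (num costs, residual dim, total block width).
start_cols = jnp.array([[0, 2], [2, 4]])  # Column offset of each block, per cost.
block_num_cols = (2, 2)                   # Static block widths.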

The block-row format is more GPU-friendly than traditional sparse formats like COO and CSR, which index individual nonzero entries. This leads to irregular memory access patterns and poor cache utilization on GPUs, where performance depends on coalesced memory access and high arithmetic intensity. The block-row format sidesteps these issues:

  • Dense matmul kernels. Jacobian blocks from each cost type are stacked into dense arrays (blocks_concat). Matrix-vector products become batched dense matmuls, which map directly to optimized GPU kernels (cuBLAS, etc.).

  • Structured indexing. While start_cols involves dynamic indexing for gather/scatter operations, the block structure (number of blocks, their sizes) is static. This is more regular than arbitrary element-wise sparse indexing.

  • Efficient preconditioning. Diagonal blocks of \(J^T J\) can be accumulated without materializing the full matrix, enabling Block-Jacobi preconditioning with minimal overhead.
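
To make the first point concrete, here is a rough sketch of a block-row matrix-vector product as a gather followed by one batched dense matmul, using the same (num costs, residual dim, total block width) layout as the sketch above (illustrative only, not jaxls's actual kernel):

import jax.numpy as jnp

def blockrow_matvec(blocks_concat, start_cols, block_num_cols, x):
    # blocks_concat: (num_costs, rows_per_cost, total_block_width).
    # start_cols: (num_costs, num_blocks), column offset of each block.
    offsets = jnp.concatenate(
        [start_cols[:, i : i + 1] + jnp.arange(w) for i, w in enumerate(block_num_cols)],
        axis=1,
    )                        # (num_costs, total_block_width) column indices.
    x_gathered = x[offsets]  # Gather the entries of x that each cost reads.
    return jnp.einsum("crk,ck->cr", blocks_concat, x_gathered)

With the cost-type-B arrays sketched earlier, blockrow_matvec(blocks_concat, start_cols, (2, 2), jnp.arange(6.0)) reproduces the B rows of \(J x\).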

SparseCooMatrix and SparseCsrMatrix#

jaxls also supports standard sparse formats for compatibility with JAX’s native sparse operations and external solvers like CHOLMOD.

SparseCooMatrix (coordinate format):

Stores each nonzero as a (row, col, value) triple. Simple to construct, but inefficient for structured access since entries aren't grouped by row or column.

@jdc.pytree_dataclass
class SparseCooCoordinates:
    rows: jax.Array   # Row indices, shape (nnz,).
    cols: jax.Array   # Column indices, shape (nnz,).
    shape: tuple[int, int]

@jdc.pytree_dataclass
class SparseCooMatrix:
    values: jax.Array              # Nonzero values, shape (nnz,).
    coords: SparseCooCoordinates   # Row and column indices.

Use sparse_mode="coo" to convert to JAX’s native BCOO format for sparse linear algebra operations.
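
For reference, a small sketch of what a COO-style matrix looks like as a JAX BCOO and how a matvec is applied (the values are hypothetical):

import jax.numpy as jnp
from jax.experimental import sparse

values = jnp.array([0.8, 0.3, 0.1, 0.9])               # Nonzero entries.
coords = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])   # (row, col) of each entry.
J_bcoo = sparse.BCOO((values, coords), shape=(2, 6))
y = J_bcoo @ jnp.ones(6)                               # Sparse matrix-vector product.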

SparseCsrMatrix (compressed sparse row):

Groups entries by row for efficient row-wise access. Each row’s column indices and values are stored contiguously, with indptr marking where each row starts.

@jdc.pytree_dataclass  
class SparseCsrCoordinates:
    indices: jax.Array   # Column indices, shape (nnz,).
    indptr: jax.Array    # Row pointers, shape (nrows + 1,).
    shape: tuple[int, int]

@jdc.pytree_dataclass
class SparseCsrMatrix:
    values: jax.Array              # Nonzero values, shape (nnz,).
    coords: SparseCsrCoordinates   # CSR index structure.

CSR is required for CHOLMOD (linear_solver="cholmod"), which performs sparse Cholesky factorization on CPU. Conversion to CSR is done automatically when using CHOLMOD, but you can also explicitly request it with sparse_mode="csr".
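
To illustrate the CSR layout itself, here is a tiny standalone sketch of how indptr slices each row's entries (standard CSR semantics, not jaxls-specific code):

import jax.numpy as jnp

# A 3x4 matrix with 5 nonzeros, stored row by row.
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
indices = jnp.array([0, 2, 1, 0, 3])  # Column index of each nonzero.
indptr = jnp.array([0, 2, 3, 5])      # Row i occupies entries indptr[i]:indptr[i+1].

row = 0
lo, hi = int(indptr[row]), int(indptr[row + 1])
row_cols = indices[lo:hi]  # Columns of row 0: [0, 2].
row_vals = values[lo:hi]   # Values of row 0: [1.0, 2.0].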

Linear solvers#

Each Levenberg-Marquardt iteration requires solving the damped normal equations \((J^T J + \lambda I) \Delta x = -J^T r\); with \(\lambda = 0\) this reduces to the Gauss-Newton step. jaxls offers three linear solvers for this system:

Conjugate gradient (default)#

solution = problem.solve(linear_solver="conjugate_gradient")

This is the default and generally recommended solver.

  • Best for: Large sparse problems, GPU acceleration

  • Pros: Memory-efficient (never forms \(J^T J\)), JAX-native, works on GPU

  • Cons: May require many iterations, sensitive to conditioning

CG tolerances are adapted across outer iterations using an Eisenstat-Walker-style inexact Newton scheme: the linear system is solved loosely while the nonlinear residual is still changing quickly, then to tighter tolerances as the solve converges.
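
The exact schedule jaxls uses isn't reproduced here, but a common Eisenstat-Walker forcing-term rule (their "choice 2") looks roughly like this, with illustrative parameter values:

def cg_tolerance(residual_norm, prev_residual_norm, gamma=0.9, alpha=2.0,
                 eta_min=1e-6, eta_max=0.1):
    # Loose tolerance while the outer (nonlinear) residual is still changing a
    # lot; tighter as the nonlinear solve converges.
    eta = gamma * (residual_norm / prev_residual_norm) ** alpha
    return min(max(eta, eta_min), eta_max)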

Dense Cholesky#

solution = problem.solve(linear_solver="dense_cholesky")

  • Best for: Small problems (<1000 variables)

  • Pros: Direct solve, no iterations, numerically stable

  • Cons: Forms dense \(J^T J\) matrix, \(O(n^3)\) complexity
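
A dense solve of the damped normal equations can be sketched with JAX's Cholesky routines; J, r, and the damping value here are placeholders:

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def dense_cholesky_step(J, r, damping):
    A = J.T @ J + damping * jnp.eye(J.shape[1])  # Dense damped normal equations.
    c, lower = cho_factor(A)
    return cho_solve((c, lower), -J.T @ r)       # The update Δx.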

Sparse Cholesky (CHOLMOD)#

solution = problem.solve(linear_solver="cholmod")

  • Best for: Large sparse problems on CPU

  • Pros: Direct solve, exploits sparsity, sparse fill-reducing ordering

  • Cons: CPU only, requires SuiteSparse installation

jaxls caches the symbolic factorization (sparsity pattern analysis), so repeated solves with the same structure are fast.
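
The caching pattern resembles how scikit-sparse's CHOLMOD bindings are typically used, where the symbolic analysis can be reused across numeric factorizations. A sketch assuming scikit-sparse and SciPy are installed (not jaxls's internal code):

import numpy as np
import scipy.sparse
from sksparse.cholmod import analyze

A = scipy.sparse.random(100, 100, density=0.05, format="csc")
AtA = (A.T @ A + scipy.sparse.identity(100)).tocsc()  # SPD system matrix.

factor = analyze(AtA)          # Symbolic factorization: depends only on the sparsity pattern.
factor = factor.cholesky(AtA)  # Numeric factorization: reuses the symbolic analysis.
x = factor(np.ones(100))       # Solve AtA @ x = b.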

Preconditioning#

For conjugate gradient, preconditioning accelerates convergence by transforming the linear system. Instead of solving \(Ax = b\) directly, we solve \(M^{-1}Ax = M^{-1}b\) where the preconditioner \(M \approx A\) is easy to invert.

For the normal equations \(A = J^T J\), jaxls provides two preconditioners:
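
In JAX terms, the preconditioner enters the CG solve as an operator applied to a vector, for example via the M argument of jax.scipy.sparse.linalg.cg (a generic sketch, not jaxls's internal solver loop):

from jax.scipy.sparse.linalg import cg

def solve_preconditioned(matvec_A, b, apply_M_inv):
    # matvec_A(x) computes A @ x without materializing A;
    # apply_M_inv(x) applies the preconditioner's (approximate) inverse.
    x, _ = cg(matvec_A, b, M=apply_M_inv)
    return x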

Block Jacobi (default)#

solution = problem.solve(
    linear_solver=jaxls.ConjugateGradientConfig(
        preconditioner="block_jacobi",
    ),
)

Uses \(M = \text{blockdiag}(J^T J)\), where blocks correspond to individual variables. For variable \(i\) with Jacobian columns \(J_i\):

\[M_i = J_i^T J_i\]

Effective when off-diagonal coupling between variables is weak.
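
A rough sketch of forming and applying the block-diagonal preconditioner, assuming every variable shares the same tangent dimension (hypothetical helper, not jaxls API):

import jax.numpy as jnp

def block_jacobi_apply(J, x, block_size):
    # Split the columns of J into per-variable blocks J_i and form M_i = J_i^T J_i.
    m, n = J.shape
    J_blocks = J.reshape(m, n // block_size, block_size)       # (m, vars, d).
    M_blocks = jnp.einsum("mid,mie->ide", J_blocks, J_blocks)  # (vars, d, d).
    # Apply M^{-1} block by block.
    x_blocks = x.reshape(n // block_size, block_size, 1)
    return jnp.linalg.solve(M_blocks, x_blocks).reshape(-1)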

Point Jacobi#

solution = problem.solve(
    linear_solver=jaxls.ConjugateGradientConfig(
        preconditioner="point_jacobi",
    ),
)

Uses \(M = \text{diag}(J^T J)\), a scalar per tangent dimension:

\[M_{ii} = \sum_j J_{ji}^2\]

Cheaper to compute but less effective than block Jacobi.