Sparse matrices#
jaxls solves nonlinear least squares problems using Levenberg-Marquardt, a damped variant of the Gauss-Newton method.
This page discusses:

- The Gauss-Newton approximation and why Jacobian structure matters
- jaxls’s BlockRowSparseMatrix representation, which is designed to be GPU-friendly
- Linear solvers and preconditioning strategies
The Gauss-Newton approximation#
Newton’s method for minimizing \(f(x) = \frac{1}{2}\sum_i \|r_i(x)\|^2\) requires the Hessian \(\nabla^2 f\), which is expensive to compute. Gauss-Newton instead approximates the Hessian using only first-order information:

\[
\nabla^2 f(x) \approx J^T J
\]

where \(J = \partial r / \partial x\) is the Jacobian of the stacked residual vector \(r(x) = [r_0(x)^\top, r_1(x)^\top, \ldots]^\top\) with respect to the optimization variables.
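For intuition, here is a minimal, self-contained sketch of one (undamped) Gauss-Newton step on a toy residual. It does not use jaxls, and residual_fn and all numbers are made up for illustration:

```python
import jax
import jax.numpy as jnp

def residual_fn(x: jax.Array) -> jax.Array:
    # Toy stacked residual r(x); in practice this concatenates all cost terms.
    return jnp.array([x[0] ** 2 - 1.0, x[0] * x[1] - 0.5, x[1] - 0.3])

def gauss_newton_step(x: jax.Array) -> jax.Array:
    r = residual_fn(x)
    J = jax.jacfwd(residual_fn)(x)     # Jacobian of the stacked residual.
    # Gauss-Newton: approximate the Hessian of 0.5 * ||r||^2 by J^T J.
    JTJ = J.T @ J
    JTr = J.T @ r
    return jnp.linalg.solve(JTJ, -JTr)  # Solve J^T J dx = -J^T r.

x = jnp.array([2.0, 2.0])
print(gauss_newton_step(x))
```

Levenberg-Marquardt adds a damping term \(\lambda I\) to the same system, as discussed in the linear solvers section below.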
Why Jacobians are sparse#
Residual terms in nonlinear least squares often depend on only a subset of the optimization variables. For example, a pairwise cost between two poses only depends on those two poses, not the hundreds of other poses in a SLAM problem. This produces sparse Jacobians: row block \(i\) of \(J\) has non-zeros only in columns corresponding to the variables that \(r_i\) depends on.
This sparsity pattern corresponds to the adjacency matrix of a bipartite factor graph:
- Variables (circles): The parameters we’re optimizing
- Costs (squares): Residual functions that connect to their dependent variables
Consider an optimization problem with variables \(x_0\), \(x_1\), \(x_2\) and three types of costs: A (unary), B (binary between adjacent variables), and C (binary between non-adjacent variables).
In this example, the Jacobian is ~47% zeros, and this ratio grows with problem size. For large SLAM or bundle adjustment problems, Jacobians can be >99% sparse. Exploiting this sparsity is critical for scaling to large problems.
Sparse matrix formats#
jaxls implements three sparse matrix representations. You can select between them
using the sparse_mode parameter:
```python
# Default: block-row format (best for CG on GPU).
solution = problem.solve(sparse_mode="blockrow")

# COO format (converts to JAX BCOO).
solution = problem.solve(sparse_mode="coo")

# CSR format (always used for CHOLMOD).
solution = problem.solve(sparse_mode="csr")
```
BlockRowSparseMatrix#
jaxls’s default and recommended sparse matrix representation uses what we call sparse “block-rows”. Costs of the same type have Jacobians with identical structure: same block sizes, just different values and column positions. jaxls exploits this by batching block-rows from the same cost type.
Using the same example from above:
Each cost type becomes one SparseBlockRow:
```python
@jdc.pytree_dataclass
class BlockRowSparseMatrix:
    block_rows: tuple[SparseBlockRow, ...]  # One per cost type.


@jdc.pytree_dataclass
class SparseBlockRow:
    num_cols: jdc.Static[int]  # Total matrix width.
    block_num_cols: jdc.Static[tuple[int, ...]]  # Block widths.
    start_cols: tuple[jax.Array, ...]  # Column positions (traced).
    blocks_concat: jax.Array  # Jacobian values (traced).
```
| Cost type | Dependent variables | Blocks per row | Column placement |
|---|---|---|---|
| A | one variable (unary) | 1 | that variable’s columns |
| B | \(x_i\), \(x_{i+1}\) (adjacent pair) | 2 | two adjacent column ranges |
| C | \(x_0\), \(x_2\) (non-adjacent pair) | 2 | two separated column ranges |
The zeros are never stored. start_cols records where each block lands in the full matrix.
The block-row format is more GPU-friendly than traditional sparse formats like COO and CSR, which index individual nonzero entries. This leads to irregular memory access patterns and poor cache utilization on GPUs, where performance depends on coalesced memory access and high arithmetic intensity. The block-row format sidesteps these issues:
- Dense matmul kernels. Jacobian blocks from each cost type are stacked into dense arrays (blocks_concat). Matrix-vector products become batched dense matmuls, which map directly to optimized GPU kernels (cuBLAS, etc.); see the sketch after this list.
- Structured indexing. While start_cols involves dynamic indexing for gather/scatter operations, the block structure (number of blocks, their sizes) is static. This is more regular than arbitrary element-wise sparse indexing.
- Efficient preconditioning. Diagonal blocks of \(J^T J\) can be accumulated without materializing the full matrix, enabling block Jacobi preconditioning with minimal overhead.
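For intuition, the following self-contained sketch (not jaxls’s implementation) shows how a block-row matrix-vector product reduces to a batched dense matmul per cost type. For simplicity it assumes each block-row contains a single block; all shapes and values are made up:

```python
import jax
import jax.numpy as jnp

# One cost type: n_costs block-rows, each with one (rows_per_cost, block_w) block.
n_costs, rows_per_cost, block_w, num_cols = 4, 2, 3, 12
key = jax.random.PRNGKey(0)
blocks = jax.random.normal(key, (n_costs, rows_per_cost, block_w))  # Dense Jacobian blocks.
start_cols = jnp.array([0, 3, 6, 9])                                # Where each block starts.
x = jnp.arange(num_cols, dtype=jnp.float32)

def blockrow_matvec(blocks, start_cols, x):
    # Gather the slice of x that each block touches: (n_costs, block_w).
    x_slices = jax.vmap(
        lambda c: jax.lax.dynamic_slice_in_dim(x, c, block_w)
    )(start_cols)
    # Batched dense matmul: (n_costs, rows_per_cost, block_w) x (n_costs, block_w).
    return jnp.einsum("nrb,nb->nr", blocks, x_slices).reshape(-1)

# The rows of A @ x contributed by this cost type, stacked.
print(blockrow_matvec(blocks, start_cols, x))
```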
SparseCooMatrix and SparseCsrMatrix#
jaxls also supports standard sparse formats for compatibility with JAX’s native sparse operations and external solvers like CHOLMOD.
SparseCooMatrix (coordinate format):
Stores each nonzero as a (row, col, value) triple. Simple but inefficient for
matrix operations since entries aren’t ordered.
```python
@jdc.pytree_dataclass
class SparseCooCoordinates:
    rows: jax.Array  # Row indices, shape (nnz,).
    cols: jax.Array  # Column indices, shape (nnz,).
    shape: tuple[int, int]


@jdc.pytree_dataclass
class SparseCooMatrix:
    values: jax.Array  # Nonzero values, shape (nnz,).
    coords: SparseCooCoordinates  # Row and column indices.
```
Use sparse_mode="coo" to convert to JAX’s native BCOO format for sparse
linear algebra operations.
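For reference, the sketch below shows what this kind of COO data looks like as a JAX BCOO matrix. It is a standalone illustration (not jaxls internals), and the values and shape are made up:

```python
import jax.numpy as jnp
from jax.experimental import sparse

# A small 3x4 matrix with 5 nonzeros, in COO form.
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
rows = jnp.array([0, 0, 1, 2, 2])
cols = jnp.array([0, 2, 1, 0, 3])

# BCOO stores indices as an (nnz, 2) array of (row, col) pairs.
A = sparse.BCOO((values, jnp.stack([rows, cols], axis=1)), shape=(3, 4))

x = jnp.ones(4)
print(A @ x)        # Sparse matrix-vector product.
print(A.todense())  # Materialize for inspection.
```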
SparseCsrMatrix (compressed sparse row):
Groups entries by row for efficient row-wise access. Each row’s column indices
and values are stored contiguously, with indptr marking where each row starts.
```python
@jdc.pytree_dataclass
class SparseCsrCoordinates:
    indices: jax.Array  # Column indices, shape (nnz,).
    indptr: jax.Array  # Row pointers, shape (nrows + 1,).
    shape: tuple[int, int]


@jdc.pytree_dataclass
class SparseCsrMatrix:
    values: jax.Array  # Nonzero values, shape (nnz,).
    coords: SparseCsrCoordinates  # CSR index structure.
```
CSR is required for CHOLMOD (linear_solver="cholmod"), which performs sparse
Cholesky factorization on CPU. Conversion to CSR is done automatically when using
CHOLMOD, but you can also explicitly request it with sparse_mode="csr".
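As a refresher on how indptr is used, here is a small self-contained sketch of a CSR matrix-vector product. It mirrors the COO example above and is illustrative only; CHOLMOD consumes this structure in compiled code:

```python
import jax.numpy as jnp

# The same 3x4 example matrix as above, now in CSR form.
values = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])  # Nonzeros, grouped by row.
indices = jnp.array([0, 2, 1, 0, 3])           # Column index of each nonzero.
indptr = jnp.array([0, 2, 3, 5])               # Row i owns values[indptr[i]:indptr[i+1]].

def csr_matvec(values, indices, indptr, x):
    # Expand indptr into a per-nonzero row id, then accumulate per row.
    row_ids = jnp.repeat(jnp.arange(len(indptr) - 1), jnp.diff(indptr))
    return jnp.zeros(len(indptr) - 1).at[row_ids].add(values * x[indices])

print(csr_matvec(values, indices, indptr, jnp.ones(4)))
```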
Linear solvers#
The Gauss-Newton update requires solving \((J^T J + \lambda I) \Delta x = -J^T r\). jaxls offers three linear solvers for this system:
Conjugate gradient (default)#
```python
solution = problem.solve(linear_solver="conjugate_gradient")
```
This is the default and generally recommended solver.
- Best for: Large sparse problems, GPU acceleration
- Pros: Memory-efficient (never forms \(J^T J\)), JAX-native, works on GPU
- Cons: May require many iterations, sensitive to conditioning
Uses an inexact Newton scheme: the CG tolerance at each outer iteration is chosen adaptively with the Eisenstat-Walker rule, so early iterations don’t over-solve the linear system.
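The solve itself can be matrix-free: the normal-equations operator only needs products with \(J\) and \(J^T\), so \(J^T J\) is never formed. A standalone sketch with `jax.scipy.sparse.linalg.cg` (toy dense J and r just for illustration; lambd plays the role of the damping term):

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
J = jax.random.normal(k1, (100, 20))  # Stand-in Jacobian.
r = jax.random.normal(k2, (100,))     # Stand-in residual.
lambd = 1e-3                          # Levenberg-Marquardt damping.

def A(v):
    # (J^T J + lambda I) v, without materializing J^T J.
    return J.T @ (J @ v) + lambd * v

dx, _ = jax.scipy.sparse.linalg.cg(A, -J.T @ r, tol=1e-6)
print(jnp.linalg.norm(A(dx) + J.T @ r))  # Residual of the linear solve.
```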
Dense Cholesky#
```python
solution = problem.solve(linear_solver="dense_cholesky")
```
- Best for: Small problems (<1000 variables)
- Pros: Direct solve, no iterations, numerically stable
- Cons: Forms dense \(J^T J\) matrix, \(O(n^3)\) complexity
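Conceptually, the dense path forms the damped normal equations and factorizes them directly. A standalone sketch using `jax.scipy.linalg.cho_factor` (toy J and r, not jaxls’s internal code):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
J = jax.random.normal(k1, (50, 10))  # Stand-in Jacobian (small problem).
r = jax.random.normal(k2, (50,))
lambd = 1e-3

# Form the damped normal equations explicitly: O(n^2) memory, O(n^3) solve.
A = J.T @ J + lambd * jnp.eye(J.shape[1])
dx = cho_solve(cho_factor(A), -J.T @ r)
print(dx)
```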
Sparse Cholesky (CHOLMOD)#
```python
solution = problem.solve(linear_solver="cholmod")
```
- Best for: Large sparse problems on CPU
- Pros: Direct solve, exploits sparsity, sparse fill-reducing ordering
- Cons: CPU only, requires SuiteSparse installation
jaxls caches the symbolic factorization (sparsity pattern analysis), so repeated solves with the same structure are fast.
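jaxls drives CHOLMOD internally, so you never call it yourself. Purely for intuition, a direct sparse Cholesky solve with the scikit-sparse CHOLMOD bindings looks roughly like the sketch below (scikit-sparse here is an assumption for the illustration, not necessarily the binding jaxls uses; all matrices are stand-ins):

```python
import numpy as np
import scipy.sparse
from sksparse.cholmod import cholesky  # Requires SuiteSparse / scikit-sparse.

# Stand-in sparse SPD system A = J^T J + lambda * I, in CSC form.
J = scipy.sparse.random(200, 50, density=0.05, format="csc", random_state=0)
A = (J.T @ J + 1e-3 * scipy.sparse.identity(50)).tocsc()
b = np.ones(50)

factor = cholesky(A)  # Symbolic + numeric factorization (fill-reducing ordering).
x = factor(b)         # Back-substitution; the symbolic analysis can be reused.
print(np.linalg.norm(A @ x - b))
```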
Preconditioning#
For conjugate gradient, preconditioning accelerates convergence by transforming the linear system. Instead of solving \(Ax = b\) directly, we solve \(M^{-1}Ax = M^{-1}b\) where the preconditioner \(M \approx A\) is easy to invert.
For the normal equations \(A = J^T J\), jaxls provides two preconditioners:
Block Jacobi (default)#
```python
solution = problem.solve(
    linear_solver=jaxls.ConjugateGradientConfig(
        preconditioner="block_jacobi",
    ),
)
```
Uses \(M = \text{blockdiag}(J^T J)\), where blocks correspond to individual variables. For variable \(i\) with Jacobian columns \(J_i\):

\[
M_i = J_i^T J_i
\]
Effective when off-diagonal coupling between variables is weak.
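A minimal sketch of building and applying this preconditioner from a dense stand-in Jacobian (jaxls instead accumulates the diagonal blocks directly from the block-row structure, as noted above; the shapes here are made up):

```python
import jax
import jax.numpy as jnp

num_vars, var_dim = 5, 3
J = jax.random.normal(jax.random.PRNGKey(0), (40, num_vars * var_dim))  # Stand-in Jacobian.

# Split J's columns per variable: (num_vars, nrows, var_dim).
J_blocks = jnp.stack(jnp.split(J, num_vars, axis=1))

# Diagonal blocks M_i = J_i^T J_i: one (var_dim, var_dim) block per variable.
M_blocks = jnp.einsum("vri,vrj->vij", J_blocks, J_blocks)

def apply_block_jacobi(v):
    # Applying M^{-1} means solving each small block system independently.
    v_blocks = v.reshape(num_vars, var_dim)
    return jax.vmap(jnp.linalg.solve)(M_blocks, v_blocks).reshape(-1)

print(apply_block_jacobi(jnp.ones(num_vars * var_dim)))
```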
Point Jacobi#
```python
solution = problem.solve(
    linear_solver=jaxls.ConjugateGradientConfig(
        preconditioner="point_jacobi",
    ),
)
```
Uses \(M = \text{diag}(J^T J)\), a scalar per tangent dimension:

\[
M_{jj} = (J^T J)_{jj} = \sum_k J_{kj}^2
\]
Cheaper to compute but less effective than block Jacobi.
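A corresponding sketch for the point-Jacobi case (stand-in J, not jaxls internals): the preconditioner is just the column-wise sum of squares of \(J\), and applying \(M^{-1}\) is an elementwise division.

```python
import jax
import jax.numpy as jnp

J = jax.random.normal(jax.random.PRNGKey(0), (40, 15))  # Stand-in Jacobian.

# M_jj = sum_k J_kj^2 = (J^T J)_jj, one scalar per tangent dimension.
diag_JTJ = jnp.sum(J**2, axis=0)

def apply_point_jacobi(v):
    return v / diag_JTJ  # Elementwise M^{-1} v.

print(apply_point_jacobi(jnp.ones(15)))
```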