Custom Jacobians#

In nonlinear least squares, the Jacobian matrix \(J\) contains partial derivatives of residuals with respect to variables. It’s used to approximate the Hessian as \(J^T J\) for Gauss-Newton updates. By default, jaxls computes Jacobians automatically via JAX’s autodiff. This guide shows how to provide analytical Jacobians instead.
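For reference, with the residuals \(r_i\) stacked into a vector \(r\) and \(J\) evaluated at the current estimate, each Gauss-Newton iteration solves the normal equations \(J^T J \,\Delta x = -J^T r\) and applies the step \(\Delta x\). Whether \(J\) comes from autodiff or from a custom function, it enters the solver through this same update.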

Features used: jaxls.Cost.factory with jac_custom_fn and jac_custom_with_cache_fn.


import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import time

import jax
import jax.numpy as jnp
import jaxls

When to use custom Jacobians#

JAX’s automatic differentiation is powerful and efficient for most use cases. However, there are situations where providing analytical Jacobians can be beneficial:

  1. Known closed-form solutions. When the Jacobian has a simple analytical form that is faster to compute than autodiff.

  2. Numerical stability. When the analytical form is more numerically stable than the autodiff computation.

  3. Reusing intermediates. When the residual computation produces intermediate values that can be reused for the Jacobian.

Problem setup: 2D point fitting#

We’ll demonstrate custom Jacobians with a simple 2D point fitting problem. Given a set of target points, we want to find the optimal location that minimizes the sum of squared distances.

For a point \(p \in \mathbb{R}^2\) and target \(t_i \in \mathbb{R}^2\), the residual is the Euclidean distance: \(r_i(p) = \|p - t_i\|\).

The Jacobian of the distance with respect to the point is \(\frac{\partial r_i}{\partial p} = \frac{(p - t_i)^T}{\|p - t_i\|}\).
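This follows from applying the chain rule to \(\|p - t_i\| = \sqrt{(p - t_i)^T (p - t_i)}\), assuming \(p \neq t_i\).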

This is a simple 1x2 row vector (residual dimension 1, tangent dimension 2).

# Define a 2D point variable.
class Point2DVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(2)):
    """A 2D point variable."""


# Generate random target points.
num_targets = 100
key = jax.random.PRNGKey(42)
targets = jax.random.normal(key, (num_targets, 2))

print(f"Target points: {targets.shape}")
Target points: (100, 2)

Basic usage: jac_custom_fn#

The simplest way to provide a custom Jacobian is via the jac_custom_fn parameter. This function takes the same arguments as the residual function and returns a 2D Jacobian matrix.

Important: The Jacobian shape must be (residual_dim, sum_of_tangent_dims_of_variables). For a single Point2DVar, this is (1, 2).

def distance_jacobian(
    vals: jaxls.VarValues,
    var: Point2DVar,
    target: jax.Array,
) -> jax.Array:
    """Analytical Jacobian of distance residual.

    Returns:
        Jacobian matrix of shape (1, 2).
    """
    point = vals[var]
    diff = point - target
    dist = jnp.linalg.norm(diff)
    # Jacobian is (p - t)^T / ||p - t||, reshaped to (1, 2).
    return (diff / dist).reshape(1, 2)


@jaxls.Cost.factory(jac_custom_fn=distance_jacobian)
def distance_cost_custom(
    vals: jaxls.VarValues,
    var: Point2DVar,
    target: jax.Array,
) -> jax.Array:
    """Distance residual with custom Jacobian."""
    point = vals[var]
    return jnp.linalg.norm(point - target).reshape(1)

For comparison, here’s the same cost using autodiff:

@jaxls.Cost.factory
def distance_cost_autodiff(
    vals: jaxls.VarValues,
    var: Point2DVar,
    target: jax.Array,
) -> jax.Array:
    """Distance residual with autodiff Jacobian."""
    point = vals[var]
    return jnp.linalg.norm(point - target).reshape(1)

With cache: jac_custom_with_cache_fn#

When computing the residual produces intermediate values that can be reused for the Jacobian, use jac_custom_with_cache_fn. The residual function must return a tuple of (residual, cache), and the Jacobian function receives this cache as its second argument.

In our distance example, both the residual and Jacobian need diff = point - target and dist = ||diff||. We can compute these once and cache them:

from typing import NamedTuple


class DistanceCache(NamedTuple):
    """Cache for distance computation."""

    diff: jax.Array  # point - target.
    dist: jax.Array  # ||diff||.


def distance_jacobian_with_cache(
    vals: jaxls.VarValues,
    cache: DistanceCache,
    var: Point2DVar,
    target: jax.Array,
) -> jax.Array:
    """Jacobian using cached intermediate values.

    Args:
        vals: Variable values (not used since we have cache).
        cache: Cached diff and dist from residual computation.
        var: The point variable.
        target: Target point.

    Returns:
        Jacobian matrix of shape (1, 2).
    """
    # Reuse cached values instead of recomputing.
    return (cache.diff / cache.dist).reshape(1, 2)


@jaxls.Cost.factory(jac_custom_with_cache_fn=distance_jacobian_with_cache)
def distance_cost_cached(
    vals: jaxls.VarValues,
    var: Point2DVar,
    target: jax.Array,
) -> tuple[jax.Array, DistanceCache]:
    """Distance residual that caches intermediates for Jacobian.

    Returns:
        Tuple of (residual, cache) - the cache is passed to the Jacobian function.
    """
    point = vals[var]
    diff = point - target
    dist = jnp.linalg.norm(diff)
    cache = DistanceCache(diff=diff, dist=dist)
    return dist.reshape(1), cache

Jacobian shape requirements#

The Jacobian must have shape (residual_dim, sum_of_tangent_dims_of_variables). Let’s verify our Jacobians have the correct shape:

# Create a test point and target.
test_var = Point2DVar(id=0)
test_point = jnp.array([1.0, 2.0])
test_target = jnp.array([0.0, 0.0])
test_vals = jaxls.VarValues.make([test_var.with_value(test_point)])

# Check Jacobian shape from custom function.
jac = distance_jacobian(test_vals, test_var, test_target)
print(f"Jacobian shape: {jac.shape}")
print("Expected: (residual_dim=1, tangent_dim=2)")

# Verify against autodiff.
jac_autodiff = jax.jacrev(lambda p: jnp.linalg.norm(p - test_target).reshape(1))(
    test_point
)
print(f"\nCustom Jacobian:\n{jac}")
print(f"Autodiff Jacobian:\n{jac_autodiff}")
print(f"Match: {jnp.allclose(jac, jac_autodiff)}")
Jacobian shape: (1, 2)
Expected: (residual_dim=1, tangent_dim=2)

Custom Jacobian:
[[0.4472136 0.8944272]]
Autodiff Jacobian:
[[0.4472136 0.8944272]]
Match: True

Solving the optimization problem#

Let’s verify all three approaches produce the same result:

def solve_with_cost_factory(cost_factory, name: str) -> jax.Array:
    """Solve the point fitting problem with a given cost factory."""
    var = Point2DVar(id=0)

    # Create batched costs for all targets.
    costs = [
        cost_factory(
            Point2DVar(id=jnp.zeros(num_targets, dtype=jnp.int32)),
            targets,
        )
    ]

    # Initial guess away from the solution.
    initial_vals = jaxls.VarValues.make([var.with_value(jnp.array([5.0, 5.0]))])

    # Solve.
    problem = jaxls.LeastSquaresProblem(costs, [var]).analyze()
    solution = problem.solve(initial_vals, verbose=False)

    result = solution[var]
    print(f"{name}: point = [{result[0]:.6f}, {result[1]:.6f}]")
    return result


result_autodiff = solve_with_cost_factory(distance_cost_autodiff, "Autodiff")
result_custom = solve_with_cost_factory(distance_cost_custom, "Custom ")
result_cached = solve_with_cost_factory(distance_cost_cached, "Cached ")

# The optimal point should be close to the mean of the targets:
# the least-squares cost here is the sum of squared distances, and the exact
# minimizer of that objective is the mean of the target points.
print(f"\nTarget mean: [{targets[:, 0].mean():.6f}, {targets[:, 1].mean():.6f}]")
INFO     | Building optimization problem with 100 terms and 1 variables: 100 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 100 costs, 1 variables each: distance_cost_autodiff
Autodiff: point = [-0.016823, 0.043048]
INFO     | Building optimization problem with 100 terms and 1 variables: 100 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 100 costs, 1 variables each: distance_cost_custom
Custom : point = [-0.016822, 0.043048]
INFO     | Building optimization problem with 100 terms and 1 variables: 100 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 100 costs, 1 variables each: distance_cost_cached
Cached : point = [-0.016822, 0.043048]

Target mean: [-0.024155, 0.041008]

Performance comparison#

Let’s compare the timing of autodiff vs custom Jacobians. Note that for this simple example, the difference may be small or even favor autodiff due to JAX’s optimizations. Custom Jacobians are most beneficial for:

  • Complex functions where the analytical Jacobian is simpler

  • Cases where intermediate values can be heavily reused

  • Very high-dimensional problems where autodiff has significant overhead

def benchmark_solver(cost_factory, name: str, num_runs: int = 20) -> float:
    """Benchmark a solver configuration.

    Returns:
        Minimum time per solve in milliseconds.
    """
    var = Point2DVar(id=0)

    costs = [
        cost_factory(
            Point2DVar(id=jnp.zeros(num_targets, dtype=jnp.int32)),
            targets,
        )
    ]

    initial_vals = jaxls.VarValues.make([var.with_value(jnp.array([5.0, 5.0]))])
    problem = jaxls.LeastSquaresProblem(costs, [var]).analyze()

    # Warmup (JIT compilation).
    solution = problem.solve(initial_vals, verbose=False)
    jax.block_until_ready(solution[var])

    # Timed runs.
    times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        solution = problem.solve(initial_vals, verbose=False)
        jax.block_until_ready(solution[var])  # Important: wait for async execution!
        times.append(time.perf_counter() - start)

    min_time = min(times) * 1000  # Convert to ms.
    print(f"{name}: {min_time:.3f} ms (min of {num_runs} runs)")
    return min_time


print("Benchmarking solve times...\n")
t_autodiff = benchmark_solver(distance_cost_autodiff, "Autodiff      ")
t_custom = benchmark_solver(distance_cost_custom, "Custom Jacobian")
t_cached = benchmark_solver(distance_cost_cached, "With Cache     ")
Benchmarking solve times...

INFO     | Building optimization problem with 100 terms and 1 variables: 100 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 100 costs, 1 variables each: distance_cost_autodiff
Autodiff      : 0.672 ms (min of 20 runs)
INFO     | Building optimization problem with 100 terms and 1 variables: 100 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 100 costs, 1 variables each: distance_cost_custom
Custom Jacobian: 0.618 ms (min of 20 runs)
INFO     | Building optimization problem with 100 terms and 1 variables: 100 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 100 costs, 1 variables each: distance_cost_cached
With Cache     : 0.749 ms (min of 20 runs)

Multiple variables#

When a cost involves multiple variables, the Jacobian concatenates their tangent dimensions in the order they appear as arguments. For example, a cost with two Point2DVar variables would have a Jacobian of shape (residual_dim, 4) (2 + 2 tangent dims).

def two_point_jacobian(
    vals: jaxls.VarValues,
    var_a: Point2DVar,
    var_b: Point2DVar,
) -> jax.Array:
    """Jacobian for distance between two points.

    The residual is ||p_a - p_b||, and the Jacobian has shape (1, 4)
    with columns [d/dp_a, d/dp_b].
    """
    p_a = vals[var_a]
    p_b = vals[var_b]
    diff = p_a - p_b
    dist = jnp.linalg.norm(diff)
    # d/dp_a = (p_a - p_b) / ||p_a - p_b||.
    # d/dp_b = -(p_a - p_b) / ||p_a - p_b||.
    jac_a = diff / dist
    jac_b = -diff / dist
    return jnp.concatenate([jac_a, jac_b]).reshape(1, 4)


@jaxls.Cost.factory(jac_custom_fn=two_point_jacobian)
def two_point_distance(
    vals: jaxls.VarValues,
    var_a: Point2DVar,
    var_b: Point2DVar,
) -> jax.Array:
    """Distance between two points."""
    return jnp.linalg.norm(vals[var_a] - vals[var_b]).reshape(1)


# Verify the multi-variable Jacobian shape.
var_a = Point2DVar(id=0)
var_b = Point2DVar(id=1)
test_vals = jaxls.VarValues.make(
    [
        var_a.with_value(jnp.array([1.0, 0.0])),
        var_b.with_value(jnp.array([0.0, 1.0])),
    ]
)

jac = two_point_jacobian(test_vals, var_a, var_b)
print(f"Two-variable Jacobian shape: {jac.shape}")
print(f"Jacobian:\n{jac}")
Two-variable Jacobian shape: (1, 4)
Jacobian:
[[ 0.70710677 -0.70710677 -0.70710677  0.70710677]]

Important notes#

  1. Correctness: Custom Jacobians bypass autodiff entirely. If your Jacobian is incorrect, the solver may converge to the wrong solution or fail to converge. Always verify against autodiff during development.

  2. Shape requirements: The Jacobian must be a 2D array with shape (residual_dim, total_tangent_dim). The tangent dimensions are concatenated in the order variables appear as arguments.

  3. Caching: Use jac_custom_with_cache_fn when the residual computation produces expensive intermediate values that the Jacobian can reuse. This avoids duplicate computation.
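As a development aid for point 1, here is a minimal sketch of wrapping the autodiff comparison from earlier in this guide in a reusable helper. It assumes the Point2DVar and distance_jacobian defined above; the helper name and atol parameter are illustrative, not part of jaxls.

def check_distance_jacobian(point: jax.Array, target: jax.Array, atol: float = 1e-5) -> bool:
    """Compare the custom distance Jacobian against autodiff at one test point."""
    var = Point2DVar(id=0)
    vals = jaxls.VarValues.make([var.with_value(point)])
    jac_custom = distance_jacobian(vals, var, target)
    jac_autodiff = jax.jacrev(lambda p: jnp.linalg.norm(p - target).reshape(1))(point)
    return bool(jnp.allclose(jac_custom, jac_autodiff, atol=atol))


# Spot-check at a few random points.
check_key = jax.random.PRNGKey(0)
for _ in range(3):
    check_key, subkey = jax.random.split(check_key)
    p, t = jax.random.normal(subkey, (2, 2))  # One random point and one random target.
    print(check_distance_jacobian(p, t))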