Tips and gotchas#

Practical guidance for getting the most out of jaxls.

Features used:

  • Batched construction for efficient problem setup

  • Residual vector structure for proper Hessian approximation

  • Jacobian mode selection (jac_mode)

  • Linear solver selection (linear_solver)

  • Debugging with return_summary


import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import time

import jax
import jax.numpy as jnp
import jaxls

Model problem: circle fitting#

We’ll use circle fitting as a running example. Given 2D points sampled from a circle, we want to recover the center \((c_x, c_y)\) and radius \(r\) of the best-fit circle. (We use noise-free points here to keep the demonstration clean; the formulation is unchanged for noisy data.)
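For each point \(p_i\), the cost functions below penalize the error vector from the closest point on the circle to \(p_i\):

\[
r_i(c, r) = \bigl(\lVert p_i - c \rVert - r\bigr) \, \frac{p_i - c}{\lVert p_i - c \rVert},
\]

which vanishes exactly when \(p_i\) lies on the circle.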

class CircleVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.array([0.0, 0.0, 1.0])
):
    """Circle parameters: (center_x, center_y, radius)."""


# Generate points on a circle (no noise for cleaner demonstration).
def generate_circle_points(
    n_points: int,
    true_center: tuple[float, float],
    true_radius: float,
) -> jax.Array:
    """Generate points sampled from a circle."""
    angles = jnp.linspace(0, 2 * jnp.pi, n_points, endpoint=False)
    points = jnp.stack(
        [
            true_center[0] + true_radius * jnp.cos(angles),
            true_center[1] + true_radius * jnp.sin(angles),
        ],
        axis=-1,
    )
    return points


# Ground truth circle.
TRUE_CENTER = (2.0, 3.0)
TRUE_RADIUS = 5.0
N_POINTS = 20

points = generate_circle_points(N_POINTS, TRUE_CENTER, TRUE_RADIUS)
print(
    f"Generated {N_POINTS} points on circle at {TRUE_CENTER} with radius {TRUE_RADIUS}"
)
Generated 20 points on circle at (2.0, 3.0) with radius 5.0


import plotly.graph_objects as go
from IPython.display import HTML

# Generate true circle for reference.
theta = jnp.linspace(0, 2 * jnp.pi, 100)
true_circle_x = TRUE_CENTER[0] + TRUE_RADIUS * jnp.cos(theta)
true_circle_y = TRUE_CENTER[1] + TRUE_RADIUS * jnp.sin(theta)

fig = go.Figure()

# Data points.
fig.add_trace(
    go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="markers",
        marker=dict(size=10, color="#FF00FF"),
        name="Data points",
    )
)

# True circle.
fig.add_trace(
    go.Scatter(
        x=true_circle_x,
        y=true_circle_y,
        mode="lines",
        line=dict(color="gray", width=2, dash="dash"),
        name=f"True circle (r={TRUE_RADIUS:.1f})",
    )
)

# True center.
fig.add_trace(
    go.Scatter(
        x=[TRUE_CENTER[0]],
        y=[TRUE_CENTER[1]],
        mode="markers",
        marker=dict(size=10, color="gray", symbol="x"),
        name="True center",
    )
)

fig.update_layout(
    xaxis_title="x",
    yaxis_title="y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=350,
    margin=dict(t=20, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Batched vs naive construction#

When building optimization problems, how you construct costs significantly impacts setup time.

Naive approach: Create individual cost objects in a Python loop. Simple but slow for large problems.

Batched approach: Pass arrays of variable IDs and data. Creates all costs in one call, much faster.

@jaxls.Cost.factory
def circle_residual(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Residual for fitting a circle to a point."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    direction = diff / dist
    # 2D residual: error vector from closest circle point to actual point.
    return (dist - r) * direction

# Naive construction: loop creating individual costs.
start = time.time()
costs_naive = [circle_residual(CircleVar(id=0), points[i]) for i in range(N_POINTS)]
problem_naive = jaxls.LeastSquaresProblem(costs_naive, [CircleVar(id=0)]).analyze()
elapsed_naive = time.time() - start
print(f"Naive construction: {len(costs_naive)} cost objects, {elapsed_naive:.3f}s")
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual
Naive construction: 20 cost objects, 0.678s
# Batched construction: pass arrays directly.
start = time.time()
# Use batched variable IDs (array of zeros = all point residuals reference the same circle variable).
costs_batched = [
    circle_residual(CircleVar(id=jnp.zeros(N_POINTS, dtype=jnp.int32)), points)
]
problem_batched = jaxls.LeastSquaresProblem(costs_batched, [CircleVar(id=0)]).analyze()
elapsed_batched = time.time() - start
print(
    f"Batched construction: {len(costs_batched)} cost object(s), {elapsed_batched:.3f}s"
)
print(f"Speedup: {elapsed_naive / elapsed_batched:.1f}x")
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual
Batched construction: 1 cost object(s), 0.102s
Speedup: 6.7x

jaxls’s problem analysis step ensures that both approaches produce identical results and solve times, but batched construction sets up faster because it avoids both the Python loop and the per-cost vectorization analysis.

# Verify both give the same solution.
solution_naive = problem_naive.solve(verbose=False)
jax.block_until_ready(solution_naive)

solution_batched = problem_batched.solve(verbose=False)
jax.block_until_ready(solution_batched)

params_naive = solution_naive[CircleVar(id=0)]
params_batched = solution_batched[CircleVar(id=0)]

print(
    f"Naive solution:   center=({params_naive[0]:.3f}, {params_naive[1]:.3f}), radius={params_naive[2]:.3f}"
)
print(
    f"Batched solution: center=({params_batched[0]:.3f}, {params_batched[1]:.3f}), radius={params_batched[2]:.3f}"
)
print(f"Max difference: {float(jnp.max(jnp.abs(params_naive - params_batched))):.2e}")
Naive solution:   center=(2.000, 3.000), radius=5.000
Batched solution: center=(2.000, 3.000), radius=5.000
Max difference: 0.00e+00
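Batched IDs are not limited to a single variable: distinct entries in the ID array attach each residual to a different variable, so one call can build costs for many circles at once. A sketch building on the definitions above (treat the exact ID/data pairing semantics as an assumption to check against the jaxls docs):

# Hypothetical extension: fit K circles independently in one batched cost.
K = 3
point_sets = jnp.stack(
    [generate_circle_points(N_POINTS, (float(k), 0.0), 1.0 + k) for k in range(K)]
)  # Shape (K, N_POINTS, 2).
ids = jnp.repeat(jnp.arange(K, dtype=jnp.int32), N_POINTS)  # Residuals of circle k -> id k.
problem_multi = jaxls.LeastSquaresProblem(
    [circle_residual(CircleVar(id=ids), point_sets.reshape(-1, 2))],
    [CircleVar(id=jnp.arange(K, dtype=jnp.int32))],
).analyze()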

Residual vector dimension#

Gauss-Newton approximates the Hessian as \(J^T J\), where \(J\) is the residual Jacobian (see Sparse matrices for details). Each cost’s contribution to this approximation has rank at most its residual dimension, so higher-dimensional residuals generally yield a better Hessian approximation and faster convergence.

For circle fitting, consider two formulations:

  • 1D residual: [dist - r], the signed distance error (shape (1,))

  • 2D residual: (dist - r) * direction, the error vector pointing from circle to point (shape (2,))

Both produce the same squared cost, but the 2D version has a 2×3 Jacobian (rank 2) instead of a 1×3 Jacobian (rank 1).
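More generally, a residual \(r(x) \in \mathbb{R}^m\) over a tangent space of dimension \(n\) has Jacobian \(J \in \mathbb{R}^{m \times n}\), and

\[
\operatorname{rank}(J^T J) \le \min(m, n),
\]

so each 2D point residual can constrain up to two of the circle’s three parameters, while the 1D version constrains only one direction in parameter space.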

# 1D residual: scalar distance error.
@jaxls.Cost.factory
def circle_residual_1d(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Returns 1D residual: signed distance to circle."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    return jnp.array([dist - r])  # Shape (1,).


# 2D residual: error vector from circle to point.
@jaxls.Cost.factory
def circle_residual_2d(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Returns 2D residual: vector from closest circle point to actual point."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    direction = diff / dist
    # Error vector: (dist - r) * direction.
    return (dist - r) * direction  # Shape (2,).

# Compare convergence.
circle_var = CircleVar(id=0)
initial_guess = jnp.array([0.0, 0.0, 1.0])
batched_var = CircleVar(id=jnp.zeros(N_POINTS, dtype=jnp.int32))

# 1D residual problem.
problem_1d = jaxls.LeastSquaresProblem(
    [circle_residual_1d(batched_var, points)], [circle_var]
).analyze()

# 2D residual problem.
problem_2d = jaxls.LeastSquaresProblem(
    [circle_residual_2d(batched_var, points)], [circle_var]
).analyze()

# Solve both.
initial_vals = jaxls.VarValues.make([circle_var.with_value(initial_guess)])

sol_1d, summary_1d = problem_1d.solve(initial_vals, verbose=False, return_summary=True)
sol_2d, summary_2d = problem_2d.solve(initial_vals, verbose=False, return_summary=True)

params_1d = sol_1d[circle_var]
params_2d = sol_2d[circle_var]

print(
    f"Initial guess: center=({initial_guess[0]:.1f}, {initial_guess[1]:.1f}), radius={initial_guess[2]:.1f}"
)
print(
    f"Ground truth:  center=({TRUE_CENTER[0]:.1f}, {TRUE_CENTER[1]:.1f}), radius={TRUE_RADIUS:.1f}"
)
print()
print(
    f"1D residual: {int(summary_1d.iterations)} iterations -> center=({params_1d[0]:.3f}, {params_1d[1]:.3f}), radius={params_1d[2]:.3f}"
)
print(
    f"2D residual: {int(summary_2d.iterations)} iterations -> center=({params_2d[0]:.3f}, {params_2d[1]:.3f}), radius={params_2d[2]:.3f}"
)
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual_1d
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual_2d
Initial guess: center=(0.0, 0.0), radius=1.0
Ground truth:  center=(2.0, 3.0), radius=5.0

1D residual: 4 iterations -> center=(2.000, 3.000), radius=5.000
2D residual: 5 iterations -> center=(2.000, 3.000), radius=5.000

In this small, noise-free example both formulations converge in a handful of iterations, but the principle applies broadly and matters more as problems grow. For reprojection error in bundle adjustment, return the 2D pixel error vector directly rather than computing jnp.linalg.norm(error): the 2D residual provides a rank-2 Jacobian per observation, improving the Hessian approximation.
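As a minimal, self-contained sketch of the pattern (the pinhole project helper and its focal length are illustrative assumptions, not jaxls API):

import jax
import jax.numpy as jnp


def project(point_cam: jax.Array, focal: float = 500.0) -> jax.Array:
    """Toy pinhole projection: camera-frame (x, y, z) -> pixel coordinates."""
    return focal * point_cam[:2] / point_cam[2]


# Preferred: the full 2D pixel error (rank-2 Jacobian per observation).
def reproj_residual_2d(point_cam: jax.Array, observed_px: jax.Array) -> jax.Array:
    return project(point_cam) - observed_px  # Shape (2,).


# Avoid: collapsing to a scalar norm (rank-1 Jacobian per observation).
def reproj_residual_1d(point_cam: jax.Array, observed_px: jax.Array) -> jax.Array:
    return jnp.linalg.norm(project(point_cam) - observed_px)[None]  # Shape (1,).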

Jacobian mode selection#

jaxls uses autodiff to compute Jacobians. The jac_mode parameter controls whether to use forward-mode or reverse-mode differentiation:

  • "auto" (default): Automatically chooses based on dimensions

  • "forward": Rule of thumb: better when residual_dim > tangent_dim

  • "reverse": Rule of thumb: better when tangent_dim > residual_dim

These are rough heuristics, not strict rules. In practice, it can be worth trying both modes to see which is faster for your specific problem. The "auto" choice is usually reasonable, but manual selection may help in some cases.
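The intuition: forward mode assembles the Jacobian one column per input (tangent) dimension via JVPs, while reverse mode assembles it one row per output (residual) dimension via VJPs. A quick plain-JAX illustration:

import jax
import jax.numpy as jnp


def toy_residual(x: jax.Array) -> jax.Array:
    """Toy residual: 3 inputs -> 2 outputs."""
    return jnp.array([x[0] * x[1], x[1] + x[2]])


x = jnp.ones(3)
J_fwd = jax.jacfwd(toy_residual)(x)  # Assembled from 3 JVPs (one per input).
J_rev = jax.jacrev(toy_residual)(x)  # Assembled from 2 VJPs (one per output).
assert bool(jnp.allclose(J_fwd, J_rev))  # Same (2, 3) Jacobian either way.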

@jaxls.Cost.factory(jac_mode="forward")
def circle_residual_forward(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Circle residual with forward-mode Jacobian."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    direction = diff / dist
    return (dist - r) * direction


@jaxls.Cost.factory(jac_mode="reverse")
def circle_residual_reverse(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Circle residual with reverse-mode Jacobian."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    direction = diff / dist
    return (dist - r) * direction

circle_var = CircleVar(id=0)
batched_var = CircleVar(id=jnp.zeros(N_POINTS, dtype=jnp.int32))

# Solve with each mode.
for name, cost_fn in [
    ("auto", circle_residual),
    ("forward", circle_residual_forward),
    ("reverse", circle_residual_reverse),
]:
    problem = jaxls.LeastSquaresProblem(
        [cost_fn(batched_var, points)], [circle_var]
    ).analyze()

    # Warm up JIT.
    _ = problem.solve(verbose=False)
    jax.block_until_ready(_)

    # Time the solve.
    start = time.time()
    for _ in range(10):
        sol = problem.solve(verbose=False)
        jax.block_until_ready(sol)
    elapsed = (time.time() - start) / 10

    params = sol[circle_var]
    print(
        f"{name:8s}: {elapsed * 1000:.2f}ms, center=({params[0]:.3f}, {params[1]:.3f}), radius={params[2]:.3f}"
    )
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual
auto    : 0.95ms, center=(2.000, 3.000), radius=5.000
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual_forward
forward : 0.98ms, center=(2.000, 3.000), radius=5.000
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual_reverse
reverse : 0.90ms, center=(2.000, 3.000), radius=5.000

Linear solver choice#

jaxls supports three linear solvers. A quick guide:

Solver                        Best for
----------------------------  -------------------------------
conjugate_gradient (default)  Large sparse problems, GPU
dense_cholesky                Small problems (<500 variables)
cholmod                       Large sparse problems on CPU

For details on solver internals and preconditioning options, see Sparse matrices.

circle_var = CircleVar(id=0)
batched_var = CircleVar(id=jnp.zeros(N_POINTS, dtype=jnp.int32))
problem = jaxls.LeastSquaresProblem(
    [circle_residual(batched_var, points)], [circle_var]
).analyze()

for solver in ["conjugate_gradient", "dense_cholesky", "cholmod"]:
    # Warm up.
    _ = problem.solve(verbose=False, linear_solver=solver)
    jax.block_until_ready(_)

    # Time it.
    start = time.time()
    for _ in range(10):
        sol = problem.solve(verbose=False, linear_solver=solver)
        jax.block_until_ready(sol)
    elapsed = (time.time() - start) / 10

    params = sol[circle_var]
    print(f"{solver:22s}: {elapsed * 1000:.2f}ms")
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual
conjugate_gradient    : 0.92ms
dense_cholesky        : 0.76ms
cholmod               : 2.74ms

Numerical stability#

Common numerical issues and how to avoid them:

  1. sqrt(x) near zero: Use sqrt(x + epsilon) to avoid NaN gradients.

  2. Division by near-zero: Add epsilon to denominators.

  3. Large residuals: Consider robust loss functions (Huber, etc.); a minimal sketch follows the gradient demo below.

# Unstable: sqrt(0) has undefined gradient.
@jaxls.Cost.factory
def distance_unstable(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Gradient is NaN when point == center."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2))  # NaN gradient at diff=0.
    direction = diff / dist
    return (dist - r) * direction


# Stable: Add small epsilon for numerical stability.
@jaxls.Cost.factory
def distance_stable(
    vals: jaxls.VarValues,
    circle_var: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Epsilon prevents NaN gradients."""
    params = vals[circle_var]
    center = params[:2]
    r = params[2]
    diff = point - center
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)  # Safe.
    direction = diff / dist
    return (dist - r) * direction

# Demonstrate the issue: point exactly at center.
test_point = jnp.array([2.0, 3.0])  # Same as TRUE_CENTER.


# Gradient check on the raw distance function.
def dist_unstable(params, point):
    diff = point - params[:2]
    return jnp.sqrt(jnp.sum(diff**2))


def dist_stable(params, point):
    diff = point - params[:2]
    return jnp.sqrt(jnp.sum(diff**2) + 1e-6)


params_at_center = jnp.array([2.0, 3.0, 5.0])  # Center matches point.

grad_u = jax.grad(dist_unstable)(params_at_center, test_point)
grad_s = jax.grad(dist_stable)(params_at_center, test_point)

print(f"Point at center: {test_point}")
print(f"Unstable gradient: {grad_u} (contains NaN!)")
print(f"Stable gradient:   {grad_s} (well-defined)")
Point at center: [2. 3.]
Unstable gradient: [nan nan  0.] (contains NaN!)
Stable gradient:   [-0. -0.  0.] (well-defined)
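For item 3, a minimal hand-rolled robustification rescales the residual by the square root of a Huber-style weight, so that its squared norm grows only linearly for large errors. This is a sketch of the general technique (an assumed helper, not a jaxls built-in):

import jax
import jax.numpy as jnp


def huber_scale(residual: jax.Array, delta: float = 1.0) -> jax.Array:
    """Rescale a residual so its squared norm follows a Huber-style loss."""
    norm = jnp.sqrt(jnp.sum(residual**2) + 1e-12)  # Epsilon per item 1 above.
    weight = jnp.where(norm <= delta, 1.0, delta / norm)  # IRLS-style weight.
    return jnp.sqrt(weight) * residual  # Squared norm: min(norm**2, delta * norm).

Wrapping a cost’s return value in huber_scale caps the influence of outlier points on the \(J^T J\) approximation.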

Debugging with return_summary#

When optimization doesn’t converge as expected, use return_summary=True to inspect the solve history.

circle_var = CircleVar(id=0)
batched_var = CircleVar(id=jnp.zeros(N_POINTS, dtype=jnp.int32))
problem = jaxls.LeastSquaresProblem(
    [circle_residual(batched_var, points)], [circle_var]
).analyze()

# Solve with summary.
solution, summary = problem.solve(verbose=False, return_summary=True)

print(f"Iterations: {int(summary.iterations)}")
print("\nTermination criteria (cost_delta, grad_mag, param_delta, max_iters):")
print(f"  {summary.termination_criteria}")
print("\nTermination deltas (cost_delta, grad_mag, param_delta):")
print(f"  {summary.termination_deltas}")
INFO     | Building optimization problem with 20 terms and 1 variables: 20 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 20 costs, 1 variables each: circle_residual
Iterations: 5

Termination criteria (cost_delta, grad_mag, param_delta):
  [False False  True]

Termination deltas (cost_delta, grad_mag, param_delta):
  [4.6153852e-01 7.1525574e-06 5.3293412e-08]
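Here only the third flag is set: the solve terminated because the parameter update (param_delta) fell below tolerance.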


# Plot cost history.
import plotly.graph_objects as go
from IPython.display import HTML

n_iters = int(summary.iterations)
cost_history = summary.cost_history[:n_iters]

fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=list(range(1, n_iters + 1)),
        y=cost_history,
        mode="lines+markers",
        marker=dict(size=8),
        line=dict(width=2),
    )
)
fig.update_layout(
    title="Optimization Convergence",
    xaxis_title="Iteration",
    yaxis_title="Cost",
    height=350,
    margin=dict(t=40, b=40, l=60, r=40),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Visualization of fitted circle#


import plotly.graph_objects as go
from IPython.display import HTML

params = solution[circle_var]
fitted_center = params[:2]
fitted_radius = params[2]

# Generate circles for plotting.
theta = jnp.linspace(0, 2 * jnp.pi, 100)
true_circle_x = TRUE_CENTER[0] + TRUE_RADIUS * jnp.cos(theta)
true_circle_y = TRUE_CENTER[1] + TRUE_RADIUS * jnp.sin(theta)
fitted_circle_x = fitted_center[0] + fitted_radius * jnp.cos(theta)
fitted_circle_y = fitted_center[1] + fitted_radius * jnp.sin(theta)

# Initial guess circle.
init_center = jnp.array([0.0, 0.0])
init_radius = 1.0
init_circle_x = init_center[0] + init_radius * jnp.cos(theta)
init_circle_y = init_center[1] + init_radius * jnp.sin(theta)

fig = go.Figure()

# Data points.
fig.add_trace(
    go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="markers",
        marker=dict(size=10, color="#FF00FF"),
        name="Data points",
    )
)

# Initial guess circle.
fig.add_trace(
    go.Scatter(
        x=init_circle_x,
        y=init_circle_y,
        mode="lines",
        line=dict(color="orange", width=2, dash="dot"),
        name=f"Initial guess (r={init_radius:.1f})",
    )
)

# True circle.
fig.add_trace(
    go.Scatter(
        x=true_circle_x,
        y=true_circle_y,
        mode="lines",
        line=dict(color="gray", width=2, dash="dash"),
        name=f"True circle (r={TRUE_RADIUS:.1f})",
    )
)

# Fitted circle.
fig.add_trace(
    go.Scatter(
        x=fitted_circle_x,
        y=fitted_circle_y,
        mode="lines",
        line=dict(color="crimson", width=2),
        name=f"Fitted circle (r={float(fitted_radius):.3f})",
    )
)

# Centers.
fig.add_trace(
    go.Scatter(
        x=[float(init_center[0])],
        y=[float(init_center[1])],
        mode="markers",
        marker=dict(size=10, color="orange", symbol="x"),
        name="Initial center",
    )
)
fig.add_trace(
    go.Scatter(
        x=[TRUE_CENTER[0]],
        y=[TRUE_CENTER[1]],
        mode="markers",
        marker=dict(size=10, color="gray", symbol="x"),
        name="True center",
    )
)
fig.add_trace(
    go.Scatter(
        x=[float(fitted_center[0])],
        y=[float(fitted_center[1])],
        mode="markers",
        marker=dict(size=10, color="crimson", symbol="x"),
        name="Fitted center",
    )
)

fig.update_layout(
    title="Circle Fitting Result",
    xaxis_title="x",
    yaxis_title="y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=450,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Summary#

  1. Use batched construction for large problems by passing arrays instead of looping

  2. Prefer higher-dimensional residuals when possible (e.g., 2D error vectors instead of scalar distances)

  3. jac_mode="auto" usually works well; manual selection rarely needed

  4. Choose linear solver based on problem size: dense_cholesky for small, conjugate_gradient or cholmod for large

  5. Add epsilon to sqrt() and divisions for numerical stability

  6. Use return_summary=True to debug convergence issues

For more details, see jaxls.Cost, jaxls.LeastSquaresProblem, and jaxls.TerminationConfig.