Basics

An introduction to nonlinear least squares optimization with jaxls.

import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}")

import jax
import jax.numpy as jnp
import jaxls

The problem: circle fitting

Given a set of noisy 2D points that lie approximately on a circle, we want to find the circle parameters (center and radius) that best fit the data.

This is a classic nonlinear least squares problem: we define residual vectors that measure the error for each data point, then minimize the sum of squared residual norms. In jaxls, a “cost” is a term in this objective; each cost computes a residual vector, and the solver minimizes \(\sum_i \|r_i(x)\|^2\).
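To make the objective concrete, here is a minimal sketch in plain Python (standard library only, no jaxls) that evaluates \(\sum_i \|r_i(x)\|^2\) for a candidate circle. The helper name circle_objective and the four demo points are hypothetical, chosen only for illustration; they are not part of jaxls.

```python
import math


def circle_objective(params, points):
    """Sum of squared residual norms for a candidate circle (cx, cy, r).

    Hypothetical helper for illustration; not part of jaxls.
    """
    cx, cy, r = params
    total = 0.0
    for px, py in points:
        dist = math.hypot(px - cx, py - cy)
        # Residual vector from the nearest circle point to the observed point.
        rx = (dist - r) * (px - cx) / dist
        ry = (dist - r) * (py - cy) / dist
        total += rx**2 + ry**2
    return total


# Four noise-free points on the circle with center (2.0, 1.5) and radius 3.0.
demo_points = [(5.0, 1.5), (2.0, 4.5), (-1.0, 1.5), (2.0, -1.5)]

exact_cost = circle_objective((2.0, 1.5, 3.0), demo_points)
shrunk_cost = circle_objective((2.0, 1.5, 2.0), demo_points)
print(exact_cost, shrunk_cost)
```

The true parameters give zero cost, while shrinking the radius by 1 leaves every point one unit off the circle, so each of the four points contributes 1 to the objective.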

# Generate noisy points on a circle.
true_cx, true_cy, true_r = 2.0, 1.5, 3.0  # True circle parameters.
num_points = 30
noise_std = 0.15

# Space the angles evenly around the circle.
angles = jnp.linspace(0, 2 * jnp.pi, num_points, endpoint=False)

# Draw Gaussian noise from a seeded PRNG key.
key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(num_points, 2)) * noise_std

points = (
    jnp.stack(
        [
            true_cx + true_r * jnp.cos(angles),
            true_cy + true_r * jnp.sin(angles),
        ],
        axis=-1,
    )
    + noise
)

print(f"Generated {num_points} noisy points around circle")
print(f"True parameters: center=({true_cx}, {true_cy}), radius={true_r}")

Generated 30 noisy points around circle
True parameters: center=(2.0, 1.5), radius=3.0

Defining variables

Variables represent the unknowns we want to optimize. In jaxls, we define custom variable types by subclassing Var.

Each variable type specifies:

  • The data type it holds (any pytree, e.g., jax.Array, dataclasses, nested structures)

  • A default_factory that creates an initial value

jaxls also supports non-Euclidean variables for optimization on manifolds like rotations. See Non-Euclidean variables for details.

class CircleVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array([0.0, 0.0, 1.0]),
):
    """Circle parameters: [cx, cy, radius]."""


# Create a variable instance with a unique ID.
circle_var = CircleVar(id=0)

print(f"Created variable: {circle_var}")
print(f"Default value: {CircleVar.default_factory()}")

Created variable: CircleVar(id=0)
Default value: [0. 0. 1.]

Defining costs

Cost functions define what we’re optimizing. Use the @jaxls.Cost.factory decorator to create cost factories.

A cost function:

  • Takes a VarValues object as its first argument (for looking up variable values)

  • Takes additional arguments (variables and/or static data)

  • Returns a residual vector to minimize

The solver will minimize the sum of squared residuals across all costs.

@jaxls.Cost.factory
def circle_residual(
    vals: jaxls.VarValues,
    circle: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Residual for a point's distance to the circle.

    Args:
        vals: Container for looking up current variable values.
        circle: The circle variable to fit.
        point: A 2D point that should lie on the circle.

    Returns:
        2D residual vector pointing from closest circle point to the observed point.
    """
    params = vals[circle]  # Look up the current circle parameters.
    cx, cy, r = params[0], params[1], params[2]

    # Vector from center to point.
    diff = point - jnp.array([cx, cy])
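    # The small epsilon avoids a division by zero (and an undefined gradient)
    # if the point coincides with the center.
    dist_to_center = jnp.sqrt(jnp.sum(diff**2) + 1e-6)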
    direction = diff / dist_to_center

    # 2D residual: error vector from closest circle point to observed point.
    return (dist_to_center - r) * direction
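As a sanity check on this residual, the sketch below re-implements the same formula in plain Python (standard library only) and shows two properties: the norm of the 2D residual matches the radial error |distance - r| up to the 1e-6 term inside the square root, and that term keeps the computation finite when a point sits exactly on the center. The function residual_2d and the sample values are hypothetical stand-ins, not jaxls code.

```python
import math


def residual_2d(params, point, eps=1e-6):
    """Plain-Python mirror of the residual formula above (illustrative only)."""
    cx, cy, r = params
    dx, dy = point[0] - cx, point[1] - cy
    dist = math.sqrt(dx * dx + dy * dy + eps)
    return ((dist - r) * dx / dist, (dist - r) * dy / dist)


# Unit circle at the origin and a point at distance 5 from the center.
res = residual_2d((0.0, 0.0, 1.0), (3.0, 4.0))
res_norm = math.hypot(*res)
radial_error = math.hypot(3.0, 4.0) - 1.0
print(res_norm, radial_error)

# A point exactly on the center: without eps this would divide by zero.
degenerate = residual_2d((3.0, 4.0, 1.0), (3.0, 4.0))
print(degenerate)
```

Because the residual norm equals the radial error, the 2D residual and a scalar residual `distance - r` define essentially the same objective; they differ only in how the error is split across components.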

The decorator turns circle_residual into a factory for cost objects. We call it with variable instances and static data only; the vals argument is left out because the solver supplies it when it evaluates the residual:

# Create a cost for each observed point.
costs = [circle_residual(circle_var, point) for point in points]

print(f"Created {len(costs)} cost objects")
print(f"Example cost: {costs[0]}")

Created 30 cost objects
Example cost: Cost(compute_residual=<function Cost.factory.<locals>.decorator.<locals>.inner.<locals>.<lambda> at 0x1063d7380>, args=((CircleVar(id=0), Array([5.0908647, 1.6198566], dtype=float32)), {}), kind='l2_squared', jac_mode='auto', jac_batch_size=None, jac_custom_fn=None, jac_custom_with_cache_fn=None, name='circle_residual')

For large problems, batched construction (passing arrays of IDs and data) is more efficient. See Tips and gotchas for details.

Building problems

A LeastSquaresProblem bundles costs and variables together.

We can call .show() to visualize the problem structure. In notebooks this displays inline; outside notebooks it opens in a browser.

# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [circle_var])

# Visualize the problem structure.
problem.show()

# Analyze the problem structure for efficient solving.
problem = problem.analyze()

INFO     | Building optimization problem with 30 terms and 1 variables: 30 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 30 costs, 1 variables each: circle_residual

Solving

Call .solve() to run the Levenberg-Marquardt optimizer. The solver iteratively adjusts the variables to minimize the sum of squared residuals.

# Solve the problem (uses default initial values from the variable's default_factory).
solution = problem.solve()

# Access the solution using the variable as a key.
result = solution[circle_var]
est_cx, est_cy, est_r = result[0], result[1], result[2]

print(
    f"\nEstimated: center=({float(est_cx):.3f}, {float(est_cy):.3f}), radius={float(est_r):.3f}"
)
print(f"True:      center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")
print(
    f"\nCenter error: {float(jnp.sqrt((est_cx - true_cx) ** 2 + (est_cy - true_cy) ** 2)):.4f}"
)
print(f"Radius error: {float(abs(est_r - true_r)):.4f}")

INFO     |  step #0: cost=270.3013 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - circle_residual(30): 270.30130 (avg 4.50502)
INFO     |      accepted=True ATb_norm=1.13e+02 cost_prev=270.3013 cost_new=9.5613
INFO     |  step #1: cost=9.5613 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - circle_residual(30): 9.56126 (avg 0.15935)
INFO     |      accepted=True ATb_norm=1.66e+01 cost_prev=9.5613 cost_new=0.4361
INFO     |  step #2: cost=0.4361 lambd=0.0001 inexact_tol=1.0e-02
INFO     |      - circle_residual(30): 0.43613 (avg 0.00727)
INFO     |      accepted=True ATb_norm=1.18e+00 cost_prev=0.4361 cost_new=0.4008
INFO     |  step #3: cost=0.4008 lambd=0.0001 inexact_tol=4.5e-03
INFO     |      - circle_residual(30): 0.40079 (avg 0.00668)
INFO     |      accepted=False ATb_norm=1.73e-03 cost_prev=0.4008 cost_new=0.4008
INFO     | Terminated @ iteration #4: cost=0.4008 criteria=[1 0 0], term_deltas=1.5e-07,1.7e-03,2.5e-05

Estimated: center=(1.922, 1.475), radius=2.996
True:      center=(2.000, 1.500), radius=3.000

Center error: 0.0819
Radius error: 0.0045
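For intuition about what the solver log reports, the following is a minimal Levenberg-Marquardt sketch in plain Python (standard library only). It is not jaxls's implementation, and it rests on several labeled assumptions: it uses the scalar radial residual `distance - r` rather than the 2D residual, damps with `lambd * I`, halves `lambd` on an accepted step and doubles it on a rejected one (the halving is merely consistent with the lambd values printed above), and fits noise-free points so the result can be checked exactly. All function names are hypothetical.

```python
import math


def solve_linear(a, b):
    """Solve a x = b by Gaussian elimination with partial pivoting."""
    n = len(b)
    m = [row[:] + [rhs] for row, rhs in zip(a, b)]  # augmented copy
    for col in range(n):
        pivot = max(range(col, n), key=lambda i: abs(m[i][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for row in range(col + 1, n):
            factor = m[row][col] / m[col][col]
            for k in range(col, n + 1):
                m[row][k] -= factor * m[col][k]
    x = [0.0] * n
    for row in reversed(range(n)):
        tail = sum(m[row][k] * x[k] for k in range(row + 1, n))
        x[row] = (m[row][n] - tail) / m[row][row]
    return x


def residuals_and_jacobian(params, points):
    """Scalar residuals (distance - r) and their Jacobian w.r.t. (cx, cy, r)."""
    cx, cy, r = params
    res, jac = [], []
    for px, py in points:
        dx, dy = px - cx, py - cy
        dist = math.sqrt(dx * dx + dy * dy + 1e-12)
        res.append(dist - r)
        jac.append([-dx / dist, -dy / dist, -1.0])
    return res, jac


def fit_circle_lm(points, params, lambd=5e-4, max_iters=100):
    """Damped Gauss-Newton loop; the accept/reject rule is an assumption."""
    params = list(params)
    res, jac = residuals_and_jacobian(params, points)
    cost = sum(v * v for v in res)
    for _ in range(max_iters):
        # Damped normal equations: (J^T J + lambd * I) delta = -J^T r.
        jtj = [[sum(row[i] * row[j] for row in jac) for j in range(3)] for i in range(3)]
        jtr = [sum(row[i] * v for row, v in zip(jac, res)) for i in range(3)]
        for i in range(3):
            jtj[i][i] += lambd
        delta = solve_linear(jtj, [-g for g in jtr])
        candidate = [p + d for p, d in zip(params, delta)]
        new_res, new_jac = residuals_and_jacobian(candidate, points)
        new_cost = sum(v * v for v in new_res)
        if new_cost < cost:  # accept the step and trust the model more
            params, res, jac, cost = candidate, new_res, new_jac, new_cost
            lambd *= 0.5
        else:  # reject the step and damp more strongly
            lambd *= 2.0
        if max(abs(d) for d in delta) < 1e-10:
            break
    return params, cost


# Twelve noise-free points on the circle with center (2.0, 1.5) and radius 3.0.
demo_points = [
    (2.0 + 3.0 * math.cos(2 * math.pi * k / 12), 1.5 + 3.0 * math.sin(2 * math.pi * k / 12))
    for k in range(12)
]
estimate, final_cost = fit_circle_lm(demo_points, (0.0, 0.0, 1.0))
print(estimate, final_cost)
```

Starting from the same initial value as the default_factory, [0, 0, 1], the loop should recover the true center and radius because the demo points carry no noise. With noisy data, as in the run above, the final cost stays above zero and the estimate lands near, not on, the true parameters.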

Visualization

import plotly.graph_objects as go
from IPython.display import HTML


def make_circle_trace(
    cx: float, cy: float, r: float, name: str, color: str, dash: str = "solid"
) -> go.Scatter:
    """Create a Plotly trace for a circle.

    Args:
        cx: Circle center x coordinate.
        cy: Circle center y coordinate.
        r: Circle radius.
        name: Legend name.
        color: Line color.
        dash: Line dash style.

    Returns:
        Plotly Scatter trace for the circle.
    """
    theta = jnp.linspace(0, 2 * jnp.pi, 100)
    x = cx + r * jnp.cos(theta)
    y = cy + r * jnp.sin(theta)
    return go.Scatter(
        x=x,
        y=y,
        mode="lines",
        name=name,
        line=dict(color=color, width=2, dash=dash),
    )


fig = go.Figure()

# Noisy data points.
fig.add_trace(
    go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="markers",
        name="Noisy points",
        marker=dict(size=8, color="steelblue"),
        hovertemplate="(%{x:.2f}, %{y:.2f})",
    )
)

# True circle.
fig.add_trace(
    make_circle_trace(true_cx, true_cy, true_r, "True circle", "green", "dash")
)

# Fitted circle.
fig.add_trace(
    make_circle_trace(
        float(est_cx), float(est_cy), float(est_r), "Fitted circle", "crimson"
    )
)

# Center points.
fig.add_trace(
    go.Scatter(
        x=[true_cx, float(est_cx)],
        y=[true_cy, float(est_cy)],
        mode="markers",
        marker=dict(size=10, symbol="x", color=["green", "crimson"]),
        name="Centers",
        showlegend=False,
    )
)

fig.update_layout(
    title="Circle Fitting with jaxls",
    xaxis_title="x",
    yaxis_title="y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=450,
    margin=dict(t=40, b=40, l=40, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Summary

The key steps for solving nonlinear least squares problems with jaxls:

  1. Define variables by subclassing Var with a default_factory

  2. Define costs using @jaxls.Cost.factory, returning residual vectors

  3. Build the problem with LeastSquaresProblem and call .analyze()

  4. Solve with .solve() and read the results by indexing the returned VarValues object with a variable, as in solution[circle_var]
