Basics#
An introduction to nonlinear least squares optimization with jaxls.
Features used:
Var subclassing for custom variable types
@jaxls.Cost.factory for defining cost functions
LeastSquaresProblem for building and solving optimization problems
VarValues for accessing solution values
import jax
import jax.numpy as jnp
import jaxls
The problem: circle fitting#
Given a set of noisy 2D points that lie approximately on a circle, we want to find the circle parameters (center and radius) that best fit the data.
This is a classic nonlinear least squares problem: we define residual vectors that measure the error for each data point, then minimize the sum of squared residual norms. In jaxls, a “cost” is a term in this objective; each cost computes a residual vector, and the solver minimizes \(\sum_i \|r_i(x)\|^2\).
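To make the objective concrete, here is a minimal sketch in plain JAX, without jaxls, that evaluates \(\sum_i \|r_i(x)\|^2\) for the circle problem. The names `residual` and `objective` are hypothetical helpers for illustration only, and the points are noise-free so that the cost at the true parameters is essentially zero.

```python
import jax
import jax.numpy as jnp


def residual(params: jax.Array, point: jax.Array) -> jax.Array:
    """Hypothetical 2D residual r_i(x) for one point; params = [cx, cy, r]."""
    diff = point - params[:2]
    # The small constant keeps the square root differentiable at the center.
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    return (dist - params[2]) * diff / dist


def objective(params: jax.Array, points: jax.Array) -> jax.Array:
    """Sum of squared residual norms over all points."""
    residuals = jax.vmap(residual, in_axes=(None, 0))(params, points)
    return jnp.sum(residuals**2)


# Noise-free points on a circle with center (2.0, 1.5) and radius 3.0.
angles = jnp.linspace(0, 2 * jnp.pi, 30, endpoint=False)
points = jnp.stack(
    [2.0 + 3.0 * jnp.cos(angles), 1.5 + 3.0 * jnp.sin(angles)], axis=-1
)

cost_at_truth = objective(jnp.array([2.0, 1.5, 3.0]), points)
cost_at_guess = objective(jnp.array([0.0, 0.0, 1.0]), points)
print(f"cost at true parameters: {float(cost_at_truth):.6f}")
print(f"cost at a poor guess:    {float(cost_at_guess):.3f}")
```

The cost is near zero at the true parameters and large at the poor guess; a solver's job is to move from the second situation to the first.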
# Generate noisy points on a circle.
true_cx, true_cy, true_r = 2.0, 1.5, 3.0 # True circle parameters.
num_points = 30
noise_std = 0.15
# Sample angles uniformly around the circle.
key = jax.random.PRNGKey(42)
angles = jnp.linspace(0, 2 * jnp.pi, num_points, endpoint=False)
# Generate points with Gaussian noise.
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(num_points, 2)) * noise_std
points = (
    jnp.stack(
        [
            true_cx + true_r * jnp.cos(angles),
            true_cy + true_r * jnp.sin(angles),
        ],
        axis=-1,
    )
    + noise
)
print(f"Generated {num_points} noisy points around circle")
print(f"True parameters: center=({true_cx}, {true_cy}), radius={true_r}")
Generated 30 noisy points around circle
True parameters: center=(2.0, 1.5), radius=3.0
Defining variables#
Variables represent the unknowns we want to optimize. In jaxls, we define custom variable types by subclassing Var.
Each variable type specifies:
The data type it holds (any pytree, e.g., jax.Array, dataclasses, nested structures)
A default_factory that creates an initial value
jaxls also supports non-Euclidean variables for optimization on manifolds like rotations. See Non-Euclidean variables for details.
class CircleVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array([0.0, 0.0, 1.0]),
):
    """Circle parameters: [cx, cy, radius]."""
# Create a variable instance with a unique ID.
circle_var = CircleVar(id=0)
print(f"Created variable: {circle_var}")
print(f"Default value: {CircleVar.default_factory()}")
Created variable: CircleVar(id=0)
Default value: [0. 0. 1.]
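For readers new to the term, the sketch below shows what "pytree" means in plain JAX: a container whose leaves are arrays. `CircleParams` is a hypothetical structured alternative to the flat `[cx, cy, radius]` array. This section does not show such a type being passed to `jaxls.Var`, so treat it purely as an illustration of the pytree concept.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class CircleParams(NamedTuple):
    """Hypothetical structured container for circle parameters."""

    center: jax.Array  # Shape (2,).
    radius: jax.Array  # Scalar.


params = CircleParams(center=jnp.array([0.0, 0.0]), radius=jnp.array(1.0))

# JAX treats the NamedTuple as a pytree and can list its array leaves.
leaves = jax.tree_util.tree_leaves(params)
print([leaf.shape for leaf in leaves])

# tree_map applies a function to every leaf and preserves the structure.
doubled = jax.tree_util.tree_map(lambda leaf: 2.0 * leaf, params)
print(doubled)
```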
Defining costs#
Cost functions define what we’re optimizing. Use the @jaxls.Cost.factory decorator to create cost factories.
A cost function:
Takes a VarValues object as its first argument (for looking up variable values)
Takes additional arguments (variables and/or static data)
Returns a residual vector to minimize
The solver will minimize the sum of squared residuals across all costs.
@jaxls.Cost.factory
def circle_residual(
    vals: jaxls.VarValues,
    circle: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Residual for a point's distance to the circle.
    Args:
        vals: Container for looking up current variable values.
        circle: The circle variable to fit.
        point: A 2D point that should lie on the circle.
    Returns:
        2D residual vector pointing from closest circle point to the observed point.
    """
    params = vals[circle]  # Look up the current circle parameters.
    cx, cy, r = params[0], params[1], params[2]
    # Vector from center to point.
    diff = point - jnp.array([cx, cy])
    dist_to_center = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    direction = diff / dist_to_center
    # 2D residual: error vector from closest circle point to observed point.
    return (dist_to_center - r) * direction
Note that circle_residual is now a factory that creates cost objects. We call it with variable instances and static data to create costs:
# Create a cost for each observed point.
costs = [circle_residual(circle_var, point) for point in points]
print(f"Created {len(costs)} cost objects")
print(f"Example cost: {costs[0]}")
Created 30 cost objects
Example cost: Cost(compute_residual=<function Cost.factory.<locals>.decorator.<locals>.inner.<locals>.<lambda> at 0x1063d7380>, args=((CircleVar(id=0), Array([5.0908647, 1.6198566], dtype=float32)), {}), kind='l2_squared', jac_mode='auto', jac_batch_size=None, jac_custom_fn=None, jac_custom_with_cache_fn=None, name='circle_residual')
For large problems, batched construction (passing arrays of IDs and data) is more efficient. See Tips and gotchas for details.
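The residual math can be checked in plain JAX, independent of jaxls. The sketch below uses a hypothetical standalone `residual` function with the same formula as `circle_residual`, confirms that the norm of the 2D residual equals the radial error, and computes the Jacobian with respect to the circle parameters using `jax.jacfwd`. jaxls computes Jacobians itself (the printed cost shows `jac_mode='auto'`); this only illustrates what such a Jacobian looks like.

```python
import jax
import jax.numpy as jnp


def residual(params: jax.Array, point: jax.Array) -> jax.Array:
    """Standalone copy of the circle residual; params = [cx, cy, r]."""
    diff = point - params[:2]
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    return (dist - params[2]) * diff / dist


params = jnp.array([0.0, 0.0, 1.0])  # Unit circle at the origin.
point = jnp.array([3.0, 4.0])  # Distance 5 from the center.

r = residual(params, point)
# The direction vector is (almost exactly) unit length, so the residual
# norm is the radial error |dist - radius|.
radial_error = jnp.linalg.norm(r)

# Jacobian of the 2D residual with respect to the 3 circle parameters.
jac = jax.jacfwd(residual)(params, point)
print(f"residual: {r}")
print(f"radial error: {float(radial_error):.4f}")
print(f"Jacobian shape: {jac.shape}")
```

The last Jacobian column is the derivative with respect to the radius, which is the negated unit direction from the center to the point.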
Building problems#
A LeastSquaresProblem bundles costs and variables together.
We can call .show() to visualize the problem structure. In notebooks this displays inline; outside notebooks it opens in a browser.
# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [circle_var])
# Visualize the problem structure.
problem.show()
# Analyze the problem structure for efficient solving.
problem = problem.analyze()
INFO | Building optimization problem with 30 terms and 1 variables: 30 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 30 costs, 1 variables each: circle_residual
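The log line above reports that the 30 costs were vectorized as one group. As a rough analogy only, and not a description of jaxls internals, the plain-JAX sketch below shows the general idea: evaluating one residual function over a batch of points with `jax.vmap` gives the same values as a Python loop. The `residual` helper and the noise-free points are assumptions made for this illustration.

```python
import jax
import jax.numpy as jnp


def residual(params: jax.Array, point: jax.Array) -> jax.Array:
    """Standalone copy of the circle residual; params = [cx, cy, r]."""
    diff = point - params[:2]
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    return (dist - params[2]) * diff / dist


angles = jnp.linspace(0, 2 * jnp.pi, 30, endpoint=False)
points = jnp.stack(
    [2.0 + 3.0 * jnp.cos(angles), 1.5 + 3.0 * jnp.sin(angles)], axis=-1
)
params = jnp.array([0.0, 0.0, 1.0])

# One call per point: simple, but each call is dispatched separately.
looped = jnp.stack([residual(params, p) for p in points])

# One batched call: params are shared (None), points are mapped over axis 0.
batched = jax.vmap(residual, in_axes=(None, 0))(params, points)
print(batched.shape)
```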
Solving#
Call .solve() to run the Levenberg-Marquardt optimizer. The solver iteratively adjusts the variables to minimize the sum of squared residuals.
# Solve the problem (uses default initial values from the variable's default_factory).
solution = problem.solve()
# Access the solution using the variable as a key.
result = solution[circle_var]
est_cx, est_cy, est_r = result[0], result[1], result[2]
print(
    f"\nEstimated: center=({float(est_cx):.3f}, {float(est_cy):.3f}), radius={float(est_r):.3f}"
)
print(f"True: center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")
print(
    f"\nCenter error: {float(jnp.sqrt((est_cx - true_cx) ** 2 + (est_cy - true_cy) ** 2)):.4f}"
)
print(f"Radius error: {float(abs(est_r - true_r)):.4f}")
INFO | step #0: cost=270.3013 lambd=0.0005 inexact_tol=1.0e-02
INFO | - circle_residual(30): 270.30130 (avg 4.50502)
INFO | accepted=True ATb_norm=1.13e+02 cost_prev=270.3013 cost_new=9.5613
INFO | step #1: cost=9.5613 lambd=0.0003 inexact_tol=1.0e-02
INFO | - circle_residual(30): 9.56126 (avg 0.15935)
INFO | accepted=True ATb_norm=1.66e+01 cost_prev=9.5613 cost_new=0.4361
INFO | step #2: cost=0.4361 lambd=0.0001 inexact_tol=1.0e-02
INFO | - circle_residual(30): 0.43613 (avg 0.00727)
INFO | accepted=True ATb_norm=1.18e+00 cost_prev=0.4361 cost_new=0.4008
INFO | step #3: cost=0.4008 lambd=0.0001 inexact_tol=4.5e-03
INFO | - circle_residual(30): 0.40079 (avg 0.00668)
INFO | accepted=False ATb_norm=1.73e-03 cost_prev=0.4008 cost_new=0.4008
INFO | Terminated @ iteration #4: cost=0.4008 criteria=[1 0 0], term_deltas=1.5e-07,1.7e-03,2.5e-05
Estimated: center=(1.922, 1.475), radius=2.996
True: center=(2.000, 1.500), radius=3.000
Center error: 0.0819
Radius error: 0.0045
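To build intuition for what a Levenberg-Marquardt iteration does, here is a minimal dense sketch in plain JAX. It is not the jaxls implementation: the helper names, the halving and doubling of the damping value, and the fixed iteration count are all assumptions chosen for illustration. It uses noise-free points so the expected answer is known.

```python
import jax
import jax.numpy as jnp


def residuals(params: jax.Array, points: jax.Array) -> jax.Array:
    """All 2D circle residuals, flattened to shape (2 * N,)."""
    diff = points - params[:2]
    dist = jnp.sqrt(jnp.sum(diff**2, axis=-1, keepdims=True) + 1e-6)
    return ((dist - params[2]) * diff / dist).reshape(-1)


def cost(params: jax.Array, points: jax.Array) -> jax.Array:
    """Sum of squared residuals."""
    return jnp.sum(residuals(params, points) ** 2)


def lm_step(params: jax.Array, points: jax.Array, lambd: float) -> jax.Array:
    """One damped step: solve (J^T J + lambd * I) dx = -J^T r."""
    r = residuals(params, points)
    J = jax.jacfwd(residuals)(params, points)
    dx = jnp.linalg.solve(J.T @ J + lambd * jnp.eye(3), -J.T @ r)
    return params + dx


# Noise-free points on a circle with center (2.0, 1.5) and radius 3.0.
angles = jnp.linspace(0, 2 * jnp.pi, 30, endpoint=False)
points = jnp.stack(
    [2.0 + 3.0 * jnp.cos(angles), 1.5 + 3.0 * jnp.sin(angles)], axis=-1
)

params = jnp.array([0.0, 0.0, 1.0])  # Same initial guess as the default_factory.
lambd = 1e-3  # Hypothetical initial damping.
for _ in range(20):
    candidate = lm_step(params, points, lambd)
    if cost(candidate, points) < cost(params, points):
        params, lambd = candidate, lambd * 0.5  # Accept; trust the model more.
    else:
        lambd = lambd * 2.0  # Reject; take a more cautious step next time.

print(f"Estimated [cx, cy, r]: {params}")
```

Accepting only cost-reducing steps mirrors the `accepted=True` and `accepted=False` entries in the solver log, where `lambd` also shrinks after accepted steps.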
Visualization#
Summary#
The key steps for solving nonlinear least squares problems with jaxls:
Define variables by subclassing Var with a default_factory
Define costs using @jaxls.Cost.factory, returning residual vectors
Build the problem with LeastSquaresProblem and call .analyze()
Solve with .solve() and access results via the VarValues object
For more, see:
Tips and gotchas: Batched construction, residual dimensions, solver selection
Constraints: Equality and inequality constraints
Non-Euclidean variables: Lie group variables for rotations and poses
Robust costs: Handling outliers with M-estimators