Robust costs#
Using M-estimators and jax.lax.stop_gradient to handle outliers robustly in jaxls.
Standard least squares is sensitive to outliers: a few bad measurements can pull the solution away from the true parameters. By using stop_gradient to compute adaptive weights within cost functions, we can down-weight outliers and recover accurate estimates even with corrupted data.
Features used:
Var subclass for circle parameters
@jaxls.Cost.factory for robust residuals
jax.lax.stop_gradient for IRLS-style weight updates
import jax
import jax.numpy as jnp
import jaxls
import numpy as np
The outlier problem#
Consider fitting a circle to 2D points. With clean data, least squares works well. But real-world data often contains outliers – points that don’t follow the expected model due to sensor errors, misassociations, or other anomalies.
Standard least squares minimizes the sum of squared residuals:
\[\min_{\theta} \sum_i r_i(\theta)^2\]
The squaring amplifies large residuals, giving outliers disproportionate influence.
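To make the amplification concrete, the sketch below uses made-up residual magnitudes (40 inliers at 0.1 and 10 outliers at 3.0, chosen for illustration and not taken from the data generated below) and compares the outliers' share of the summed absolute residuals with their share of the summed squared residuals.

```python
import numpy as np

# Hypothetical residual magnitudes, chosen only for illustration.
inlier_residuals = np.full(40, 0.1)
outlier_residuals = np.full(10, 3.0)
residuals = np.concatenate([inlier_residuals, outlier_residuals])

# Fraction of the total loss contributed by the 10 outliers.
abs_share = np.abs(outlier_residuals).sum() / np.abs(residuals).sum()
sq_share = (outlier_residuals**2).sum() / (residuals**2).sum()

print(f"Outlier share of sum |r|: {abs_share:.3f}")
print(f"Outlier share of sum r^2: {sq_share:.3f}")
```

With these numbers the outliers are 20% of the points but account for roughly 88% of the absolute loss and over 99% of the squared loss, so the squared objective is almost entirely determined by them.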
# Generate synthetic circle data with outliers.
np.random.seed(42)
# Ground truth circle.
true_cx, true_cy, true_r = 1.0, 1.0, 2.0
# Inlier points (on the circle with small noise).
n_inliers = 40
theta_inliers = np.random.uniform(0, 2 * np.pi, n_inliers)
noise_inliers = np.random.normal(0, 0.1, n_inliers)
inlier_x = true_cx + (true_r + noise_inliers) * np.cos(theta_inliers)
inlier_y = true_cy + (true_r + noise_inliers) * np.sin(theta_inliers)
# Outlier points (scattered far from the circle).
n_outliers = 10
outlier_x = np.random.uniform(-4, 7, n_outliers)
outlier_y = np.random.uniform(-4, 7, n_outliers)
# Combine all points.
all_x = np.concatenate([inlier_x, outlier_x])
all_y = np.concatenate([inlier_y, outlier_y])
points = jnp.stack([all_x, all_y], axis=-1)
n_points = len(points)
# Track which points are outliers for visualization.
is_outlier = np.array([False] * n_inliers + [True] * n_outliers)
print(
f"Generated {n_inliers} inliers and {n_outliers} outliers ({n_outliers / n_points * 100:.0f}% outliers)"
)
print(f"True circle: center=({true_cx}, {true_cy}), radius={true_r}")
Generated 40 inliers and 10 outliers (20% outliers)
True circle: center=(1.0, 1.0), radius=2.0
Standard least squares#
First, let’s see how standard (unweighted) least squares performs:
class CircleVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.array([0.0, 0.0, 1.0])
):
    """Circle parameters: [cx, cy, r]."""


@jaxls.Cost.factory
def circle_residual(
    vals: jaxls.VarValues,
    circle: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """2D residual: error vector from closest circle point to observed point."""
    params = vals[circle]
    cx, cy, r = params[0], params[1], params[2]
    diff = point - jnp.array([cx, cy])
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-8)
    direction = diff / dist
    return (dist - r) * direction
# Create variable and costs.
circle_var = CircleVar(id=0)
# Initial guess: centroid of points, average distance as radius.
centroid = jnp.mean(points, axis=0)
avg_dist = jnp.mean(jnp.sqrt(jnp.sum((points - centroid) ** 2, axis=-1)))
initial_params = jnp.array([centroid[0], centroid[1], avg_dist])
# Build problem with batched point indices.
costs_standard = [
circle_residual(CircleVar(id=jnp.zeros(n_points, dtype=jnp.int32)), points)
]
initial_vals = jaxls.VarValues.make([circle_var.with_value(initial_params)])
# Build the problem.
problem_standard = jaxls.LeastSquaresProblem(costs_standard, [circle_var])
# Visualize the problem structure.
problem_standard.show()
# Analyze and solve.
problem_standard = problem_standard.analyze()
solution_standard = problem_standard.solve(initial_vals)
params_standard = solution_standard[circle_var]
print(
f"Standard LS result: center=({params_standard[0]:.3f}, {params_standard[1]:.3f}), radius={params_standard[2]:.3f}"
)
print(f"True parameters: center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")
INFO | Building optimization problem with 50 terms and 1 variables: 50 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 50 costs, 1 variables each: circle_residual
INFO | step #0: cost=62.2762 lambd=0.0005 inexact_tol=1.0e-02
INFO | - circle_residual(50): 62.27618 (avg 0.62276)
INFO | accepted=True ATb_norm=8.01e+00 cost_prev=62.2762 cost_new=60.3239
INFO | step #1: cost=60.3239 lambd=0.0003 inexact_tol=1.0e-02
INFO | - circle_residual(50): 60.32394 (avg 0.60324)
INFO | accepted=True ATb_norm=1.71e+00 cost_prev=60.3239 cost_new=60.2327
INFO | step #2: cost=60.2327 lambd=0.0001 inexact_tol=1.0e-02
INFO | - circle_residual(50): 60.23269 (avg 0.60233)
INFO | accepted=True ATb_norm=3.43e-01 cost_prev=60.2327 cost_new=60.2291
INFO | step #3: cost=60.2291 lambd=0.0001 inexact_tol=1.0e-02
INFO | - circle_residual(50): 60.22906 (avg 0.60229)
INFO | accepted=True ATb_norm=6.96e-02 cost_prev=60.2291 cost_new=60.2289
INFO | Terminated @ iteration #4: cost=60.2289 criteria=[1 0 0], term_deltas=2.5e-06,6.0e-02,7.6e-04
Standard LS result: center=(0.639, 1.114), radius=2.445
True parameters: center=(1.000, 1.000), radius=2.000
Robust cost functions (M-estimators)#
M-estimators replace the squared loss \(r^2\) with a robust function \(\rho(r)\) that grows more slowly for large residuals. Common choices:
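Each form below is scaled so that the IRLS weight \(w(r) = \rho'(r) / r\) matches the weight functions defined next.
Huber loss: Quadratic for small residuals, linear for large:
\[\rho(r) = \begin{cases} \frac{1}{2} r^2 & |r| \le k \\ k \left(|r| - \frac{1}{2} k\right) & |r| > k \end{cases}\]
Cauchy/Lorentzian: Soft down-weighting of outliers:
\[\rho(r) = \frac{c^2}{2} \log\left(1 + \left(\frac{r}{c}\right)^2\right)\]
Geman-McClure: Even more aggressive outlier rejection:
\[\rho(r) = \frac{1}{2} \cdot \frac{r^2}{1 + (r / \sigma)^2}\]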
def huber_weight(residual: jax.Array, k: float = 1.345) -> jax.Array:
    """Huber weight: 1 for small residuals, k/|r| for large.

    The default k=1.345 gives 95% efficiency for unit-variance Gaussian residuals.
    """
    abs_r = jnp.abs(residual) + 1e-8
    return jnp.where(abs_r <= k, 1.0, k / abs_r)


def cauchy_weight(residual: jax.Array, c: float = 2.385) -> jax.Array:
    """Cauchy/Lorentzian weight: 1 / (1 + (r/c)^2).

    Provides soft down-weighting that never fully rejects points.
    """
    return 1.0 / (1.0 + (residual / c) ** 2)


def geman_mcclure_weight(residual: jax.Array, sigma: float = 1.0) -> jax.Array:
    """Geman-McClure weight: 1 / (1 + (r/sigma)^2)^2.

    Aggressive outlier rejection -- weight drops quickly for large residuals.
    """
    r_scaled = residual / sigma
    return 1.0 / (1.0 + r_scaled**2) ** 2
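As a sanity check on how the weights relate to the losses, the sketch below writes out one common scaling of each loss \(\rho\) in plain NumPy and confirms numerically that \(\rho'(r) / r\) reproduces the three weight formulas above. The `*_rho` functions and `weight_from_rho` are illustrative helpers, not part of jaxls, and the loss scalings are an assumption chosen to match the weights.

```python
import numpy as np


def huber_rho(r, k=1.345):
    a = np.abs(r)
    return np.where(a <= k, 0.5 * r**2, k * (a - 0.5 * k))


def cauchy_rho(r, c=2.385):
    return 0.5 * c**2 * np.log1p((r / c) ** 2)


def geman_mcclure_rho(r, sigma=1.0):
    s = (r / sigma) ** 2
    return 0.5 * sigma**2 * s / (1.0 + s)


def weight_from_rho(rho, r, h=1e-6):
    """IRLS weight rho'(r) / r, with rho' from a central difference."""
    return (rho(r + h) - rho(r - h)) / (2.0 * h) / r


# Residuals away from zero and away from the Huber threshold k.
r = np.array([0.1, 0.5, 1.0, 3.0, 10.0])

w_huber = weight_from_rho(huber_rho, r)
w_cauchy = weight_from_rho(cauchy_rho, r)
w_gm = weight_from_rho(geman_mcclure_rho, r)

# Closed-form weights, same formulas as the jnp versions above.
w_huber_ref = np.where(np.abs(r) <= 1.345, 1.0, 1.345 / np.abs(r))
w_cauchy_ref = 1.0 / (1.0 + (r / 2.385) ** 2)
w_gm_ref = 1.0 / (1.0 + r**2) ** 2

for name, w in [("Huber", w_huber), ("Cauchy", w_cauchy), ("Geman-McClure", w_gm)]:
    print(f"{name:<15}", np.round(w, 4))
```

Reading the printed rows left to right shows how quickly each estimator discounts growing residuals: Huber decays like \(1/|r|\), Cauchy like \(1/r^2\), and Geman-McClure like \(1/r^4\).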
IRLS with jaxls#
Iteratively reweighted least squares (IRLS) performs robust estimation by solving a sequence of weighted least squares problems, recomputing the weights from the current residuals before each solve. The key to implementing IRLS in jaxls is using jax.lax.stop_gradient on the weights. This tells JAX to treat the weights as constants when computing Jacobians, so the solver sees a weighted least squares problem at each iteration. Because the cost function is re-evaluated at every solver iteration, the weights are recomputed from the latest estimate, implementing IRLS without an explicit outer loop.
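The effect of stop_gradient on the derivative can be seen on a toy scalar problem. The sketch below is an illustration only: `raw_residual` is a made-up one-parameter residual, and `jax.grad` stands in for whatever differentiation a solver performs. It compares the derivative of the weighted residual with and without stop_gradient on a Cauchy weight.

```python
import jax
import jax.numpy as jnp


def cauchy_weight(r, c=0.5):
    return 1.0 / (1.0 + (r / c) ** 2)


def raw_residual(theta):
    # Hypothetical residual: a measurement at 3.0 minus the parameter.
    return 3.0 - theta


def weighted_residual(theta):
    r = raw_residual(theta)
    # The weight is evaluated at the current theta but treated as a constant
    # when differentiating.
    w = jax.lax.stop_gradient(cauchy_weight(r))
    return jnp.sqrt(w) * r


def weighted_residual_no_stop(theta):
    r = raw_residual(theta)
    # Here the derivative also flows through the weight.
    return jnp.sqrt(cauchy_weight(r)) * r


theta = 1.0
w = cauchy_weight(raw_residual(theta))

jac_stop = jax.grad(weighted_residual)(theta)
jac_full = jax.grad(weighted_residual_no_stop)(theta)

print("sqrt(w) * dr/dtheta:", -(w**0.5))
print("with stop_gradient: ", float(jac_stop))
print("without:            ", float(jac_full))
```

With stop_gradient the derivative equals \(\sqrt{w}\) times the derivative of the raw residual, which is the Jacobian of a weighted least squares problem with fixed weights. Without it, the derivative also includes the change in the weight itself and no longer corresponds to an IRLS step.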
The cost function computes:
Residual \(r_i = \|p - p_{\text{circle}}\|\) for each point
Weight \(w_i = \psi(r_i)\) using the M-estimator, with stop_gradient
Weighted residual \(\sqrt{w_i} \cdot r_i\)
Since jaxls squares residuals, the objective becomes \(\sum_i w_i r_i^2\), exactly the IRLS formulation.
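For comparison, here is what an explicit IRLS loop looks like when written out by hand in NumPy. This is a minimal sketch under several assumptions and not how jaxls solves the problem: it uses the scalar residual `dist - radius` instead of the 2D residual, takes one undamped Gauss-Newton step per reweighting, and generates its own synthetic data. The name `irls_circle_fit` is hypothetical.

```python
import numpy as np


def cauchy_weight(r, c=0.5):
    return 1.0 / (1.0 + (r / c) ** 2)


def irls_circle_fit(points, params, weight_fn, n_iters=50):
    """Explicit IRLS: recompute weights, then take one weighted Gauss-Newton step."""
    params = np.asarray(params, dtype=float).copy()
    for _ in range(n_iters):
        cx, cy, radius = params
        diff = points - np.array([cx, cy])
        dist = np.sqrt(np.sum(diff**2, axis=1) + 1e-8)
        residual = dist - radius
        # Weights come from the current estimate and are then held fixed,
        # which is the role stop_gradient plays in the jaxls cost.
        w = weight_fn(residual)
        # Jacobian of (dist - radius) with respect to [cx, cy, radius].
        J = np.column_stack(
            [-diff[:, 0] / dist, -diff[:, 1] / dist, -np.ones(len(points))]
        )
        JTW = J.T * w
        # Weighted normal equations: (J^T W J) step = -J^T W r.
        step = np.linalg.solve(JTW @ J, -JTW @ residual)
        params = params + step
    return params


# Synthetic data: 40 noisy points on a circle plus 10 scattered outliers.
rng = np.random.default_rng(0)
theta = rng.uniform(0.0, 2.0 * np.pi, 40)
radii = 2.0 + rng.normal(0.0, 0.1, 40)
inliers = np.stack([1.0 + radii * np.cos(theta), 1.0 + radii * np.sin(theta)], axis=1)
outliers = rng.uniform(-4.0, 7.0, size=(10, 2))
points = np.concatenate([inliers, outliers])

# Initial guess: centroid and mean distance to the centroid.
centroid = points.mean(axis=0)
init = np.array(
    [centroid[0], centroid[1], np.linalg.norm(points - centroid, axis=1).mean()]
)

fit_ls = irls_circle_fit(points, init, lambda r: np.ones_like(r))
fit_irls = irls_circle_fit(points, init, cauchy_weight)

truth = np.array([1.0, 1.0, 2.0])
print("unit weights:  ", np.round(fit_ls, 3), "error", np.linalg.norm(fit_ls - truth))
print("Cauchy weights:", np.round(fit_irls, 3), "error", np.linalg.norm(fit_irls - truth))
```

Passing unit weights reduces the loop to ordinary Gauss-Newton least squares, so the same function gives both the non-robust and the robust fit.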
def make_robust_circle_cost(weight_fn: callable):
    """Create a robust circle cost with IRLS weights via stop_gradient.

    Args:
        weight_fn: Function mapping residual magnitude to weight (e.g., cauchy_weight).

    Returns:
        A jaxls cost factory that applies robust weighting.
    """

    @jaxls.Cost.factory
    def robust_circle_residual(
        vals: jaxls.VarValues,
        circle: CircleVar,
        point: jax.Array,
    ) -> jax.Array:
        """2D residual with automatic IRLS weighting."""
        params = vals[circle]
        cx, cy, r = params[0], params[1], params[2]
        diff = point - jnp.array([cx, cy])
        dist = jnp.sqrt(jnp.sum(diff**2) + 1e-8)
        direction = diff / dist
        residual = dist - r
        # Compute weight from current residual, but stop gradients.
        # This makes the optimizer treat weights as constants, implementing IRLS.
        weight = jax.lax.stop_gradient(weight_fn(residual))
        sqrt_weight = jnp.sqrt(weight + 1e-8)
        return sqrt_weight * residual * direction

    return robust_circle_residual


def compute_weights(
    params: jax.Array, points: jax.Array, weight_fn: callable
) -> jax.Array:
    """Compute IRLS weights for visualization."""
    cx, cy, r = params[0], params[1], params[2]
    dists = jnp.sqrt((points[:, 0] - cx) ** 2 + (points[:, 1] - cy) ** 2 + 1e-8)
    residuals = dists - r
    return weight_fn(residuals)
Robust fitting with Cauchy weights#
With the stop_gradient approach, robust fitting is just a single solve() call:
# Create robust cost with Cauchy weights.
cauchy_weight_fn = lambda r: cauchy_weight(r, c=0.5)
robust_cost = make_robust_circle_cost(cauchy_weight_fn)
# Build and solve (just one call - IRLS happens automatically via solver iterations).
costs_irls = [robust_cost(CircleVar(id=jnp.zeros(n_points, dtype=jnp.int32)), points)]
initial_vals = jaxls.VarValues.make([circle_var.with_value(initial_params)])
problem_irls = jaxls.LeastSquaresProblem(costs_irls, [circle_var]).analyze()
solution_irls = problem_irls.solve(initial_vals)
params_irls = solution_irls[circle_var]
weights_irls = compute_weights(params_irls, points, cauchy_weight_fn)
print(
f"IRLS result: center=({params_irls[0]:.3f}, {params_irls[1]:.3f}), radius={params_irls[2]:.3f}"
)
print(f"True parameters: center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")
print(
f"\nCenter error: {jnp.sqrt((params_irls[0] - true_cx) ** 2 + (params_irls[1] - true_cy) ** 2):.4f}"
)
print(f"Radius error: {jnp.abs(params_irls[2] - true_r):.4f}")
INFO | Building optimization problem with 50 terms and 1 variables: 50 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 50 costs, 1 variables each: robust_circle_residual
INFO | step #0: cost=6.1968 lambd=0.0005 inexact_tol=1.0e-02
INFO | - robust_circle_residual(50): 6.19684 (avg 0.06197)
INFO | accepted=True ATb_norm=9.32e+00 cost_prev=6.1968 cost_new=2.7644
INFO | step #1: cost=2.7644 lambd=0.0003 inexact_tol=1.0e-02
INFO | - robust_circle_residual(50): 2.76435 (avg 0.02764)
INFO | accepted=True ATb_norm=2.87e+00 cost_prev=2.7644 cost_new=2.4958
INFO | step #2: cost=2.4958 lambd=0.0001 inexact_tol=1.0e-02
INFO | - robust_circle_residual(50): 2.49581 (avg 0.02496)
INFO | accepted=True ATb_norm=2.61e-01 cost_prev=2.4958 cost_new=2.4868
INFO | step #3: cost=2.4868 lambd=0.0001 inexact_tol=7.5e-03
INFO | - robust_circle_residual(50): 2.48685 (avg 0.02487)
INFO | accepted=True ATb_norm=2.36e-02 cost_prev=2.4868 cost_new=2.4862
INFO | step #4: cost=2.4862 lambd=0.0000 inexact_tol=7.3e-03
INFO | - robust_circle_residual(50): 2.48620 (avg 0.02486)
INFO | accepted=True ATb_norm=2.40e-03 cost_prev=2.4862 cost_new=2.4861
INFO | step #5: cost=2.4861 lambd=0.0000 inexact_tol=7.3e-03
INFO | - robust_circle_residual(50): 2.48614 (avg 0.02486)
INFO | accepted=False ATb_norm=2.81e-04 cost_prev=2.4861 cost_new=2.4861
INFO | Terminated @ iteration #6: cost=2.4861 criteria=[1 0 0], term_deltas=1.9e-06,2.7e-04,4.9e-06
IRLS result: center=(0.975, 1.009), radius=2.015
True parameters: center=(1.000, 1.000), radius=2.000
Center error: 0.0266
Radius error: 0.0150
Visualization#
Compare standard least squares (pulled by outliers) vs IRLS (robust to outliers):
Weight visualization#
IRLS identifies outliers by assigning them low weights. Let’s visualize the final weights:
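Besides plotting, the final weights can be thresholded to flag likely outliers. The sketch below is self-contained and purely illustrative: the six points, the fitted parameters, and the 0.5 cutoff are all made-up values, and the weight computation mirrors compute_weights in NumPy.

```python
import numpy as np


def cauchy_weight(r, c=0.5):
    return 1.0 / (1.0 + (r / c) ** 2)


# Hypothetical fitted circle [cx, cy, r] and a handful of example points:
# four on the circle, one far outside it, one near its center.
params = np.array([1.0, 1.0, 2.0])
points = np.array(
    [[3.0, 1.0], [1.0, 3.0], [-1.0, 1.0], [1.0, -1.0], [6.0, 6.0], [1.2, 1.1]]
)

cx, cy, radius = params
dists = np.sqrt((points[:, 0] - cx) ** 2 + (points[:, 1] - cy) ** 2 + 1e-8)
weights = cauchy_weight(dists - radius)

# Illustrative cutoff: points with weight below 0.5 are treated as outliers.
flagged = weights < 0.5

for point, weight, flag in zip(points, weights, flagged):
    print(point, f"weight={weight:.3f}", "outlier" if flag else "inlier")
```

The cutoff is a judgment call that depends on the weight function and its tuning constant, so it is best chosen by inspecting the weight distribution for the data at hand.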
Comparing M-estimators#
Different weight functions give different levels of robustness:
# Fit with different M-estimators.
results = {"Standard LS": params_standard}
for name, weight_fn in [
    ("Huber", lambda r: huber_weight(r, k=0.5)),
    ("Cauchy", lambda r: cauchy_weight(r, c=0.5)),
    ("Geman-McClure", lambda r: geman_mcclure_weight(r, sigma=0.5)),
]:
    robust_cost = make_robust_circle_cost(weight_fn)
    costs = [robust_cost(CircleVar(id=jnp.zeros(n_points, dtype=jnp.int32)), points)]
    problem = jaxls.LeastSquaresProblem(costs, [circle_var]).analyze()
    solution = problem.solve(initial_vals, verbose=False)
    results[name] = solution[circle_var]
print(
    f"{'Method':<15} {'cx':>8} {'cy':>8} {'r':>8} {'Center err':>12} {'Radius err':>12}"
)
print("-" * 70)
print(
    f"{'True':<15} {true_cx:>8.3f} {true_cy:>8.3f} {true_r:>8.3f} {'-':>12} {'-':>12}"
)
for name, params in results.items():
    center_err = float(
        jnp.sqrt((params[0] - true_cx) ** 2 + (params[1] - true_cy) ** 2)
    )
    radius_err = float(jnp.abs(params[2] - true_r))
    print(
        f"{name:<15} {float(params[0]):>8.3f} {float(params[1]):>8.3f} {float(params[2]):>8.3f} {center_err:>12.4f} {radius_err:>12.4f}"
    )
INFO | Building optimization problem with 50 terms and 1 variables: 50 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 50 costs, 1 variables each: robust_circle_residual
INFO | Building optimization problem with 50 terms and 1 variables: 50 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 50 costs, 1 variables each: robust_circle_residual
INFO | Building optimization problem with 50 terms and 1 variables: 50 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 50 costs, 1 variables each: robust_circle_residual
Method                cx       cy        r   Center err   Radius err
----------------------------------------------------------------------
True               1.000    1.000    2.000            -            -
Standard LS        0.639    1.114    2.445       0.3781       0.4446
Huber              0.918    1.032    2.087       0.0880       0.0871
Cauchy             0.975    1.009    2.015       0.0266       0.0150
Geman-McClure      0.990    1.000    1.999       0.0104       0.0010
Summary#
Robust estimation in jaxls using jax.lax.stop_gradient:
Standard least squares gives outliers disproportionate influence due to the squared loss.
M-estimators (Huber, Cauchy, Geman-McClure) define weight functions that down-weight large residuals.
Using stop_gradient on weights implements IRLS automatically within the solver’s iterations – no explicit outer loop needed.
Huber provides convex robustness; Cauchy gives smooth down-weighting; Geman-McClure aggressively rejects outliers.
Implementation pattern:
weight = jax.lax.stop_gradient(weight_fn(residual))
return jnp.sqrt(weight) * residual
This makes the optimizer treat weights as constants while still updating them as the solution evolves. The solver’s iterations naturally perform IRLS.
For more details, see jaxls.Var, jaxls.Cost, and jaxls.LeastSquaresProblem.