# Robust costs

Using M-estimators and `jax.lax.stop_gradient` to handle outliers robustly in jaxls.

Standard least squares is sensitive to outliers: a few bad measurements can pull the solution away from the true parameters. By computing adaptive weights inside the cost function and wrapping them in `stop_gradient`, we can down-weight outliers and recover accurate estimates even with corrupted data.

Features used:
- {class}`~jaxls.Var` subclass for circle parameters
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` for robust residuals
- `jax.lax.stop_gradient` for IRLS-style weight updates

In [None]:
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [None]:
import jax
import jax.numpy as jnp
import jaxls
import numpy as np

## The outlier problem

Consider fitting a circle to 2D points. With clean data, least squares works well.
But real-world data often contains outliers -- points that don't follow the expected
model due to sensor errors, misassociations, or other anomalies.

Standard least squares minimizes the sum of squared residuals:

$$\min_\theta \sum_i r_i(\theta)^2$$

The squaring amplifies large residuals, giving outliers disproportionate influence.
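
As a small numeric illustration of that amplification, the sketch below uses made-up residuals (nine inliers at 0.1 and one outlier at 5.0, not the notebook's data) and compares how much of the total loss the single outlier accounts for under a squared loss versus an absolute loss.

```python
import numpy as np

# Nine small inlier residuals and one large outlier residual (made-up values).
residuals = np.array([0.1] * 9 + [5.0])

# Share of the total loss contributed by the last (outlier) residual.
squared = residuals**2
outlier_share_sq = squared[-1] / squared.sum()
outlier_share_abs = np.abs(residuals[-1]) / np.abs(residuals).sum()

print(f"Outlier share of squared loss:  {outlier_share_sq:.1%}")
print(f"Outlier share of absolute loss: {outlier_share_abs:.1%}")
```

Under the squared loss the one outlier accounts for almost all of the objective, so the optimizer spends most of its effort reducing that single term.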

In [None]:
# Generate synthetic circle data with outliers.
np.random.seed(42)

# Ground truth circle.
true_cx, true_cy, true_r = 1.0, 1.0, 2.0

# Inlier points (on the circle with small noise).
n_inliers = 40
theta_inliers = np.random.uniform(0, 2 * np.pi, n_inliers)
noise_inliers = np.random.normal(0, 0.1, n_inliers)
inlier_x = true_cx + (true_r + noise_inliers) * np.cos(theta_inliers)
inlier_y = true_cy + (true_r + noise_inliers) * np.sin(theta_inliers)

# Outlier points (scattered far from the circle).
n_outliers = 10
outlier_x = np.random.uniform(-4, 7, n_outliers)
outlier_y = np.random.uniform(-4, 7, n_outliers)

# Combine all points.
all_x = np.concatenate([inlier_x, outlier_x])
all_y = np.concatenate([inlier_y, outlier_y])
points = jnp.stack([all_x, all_y], axis=-1)
n_points = len(points)

# Track which points are outliers for visualization.
is_outlier = np.array([False] * n_inliers + [True] * n_outliers)

print(
    f"Generated {n_inliers} inliers and {n_outliers} outliers ({n_outliers / n_points * 100:.0f}% outliers)"
)
print(f"True circle: center=({true_cx}, {true_cy}), radius={true_r}")

## Standard least squares

First, let's see how standard (unweighted) least squares performs:

In [None]:
class CircleVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.array([0.0, 0.0, 1.0])
):
    """Circle parameters: [cx, cy, r]."""


@jaxls.Cost.factory
def circle_residual(
    vals: jaxls.VarValues,
    circle: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """2D residual: error vector from closest circle point to observed point."""
    params = vals[circle]
    cx, cy, r = params[0], params[1], params[2]
    diff = point - jnp.array([cx, cy])
    dist = jnp.sqrt(jnp.sum(diff**2) + 1e-8)
    direction = diff / dist
    return (dist - r) * direction
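
The `1e-8` under the square root in `circle_residual` is not explained above. One plausible reason is gradient safety: if a point coincides with the circle center, the plain Euclidean norm has an undefined derivative, and JAX's autodiff returns NaN there. The sketch below (helper names are illustrative, not part of jaxls) compares the two variants at the origin.

```python
import jax
import jax.numpy as jnp


def norm_plain(diff: jax.Array) -> jax.Array:
    # Euclidean norm with no safeguard.
    return jnp.sqrt(jnp.sum(diff**2))


def norm_eps(diff: jax.Array) -> jax.Array:
    # Same norm with a small constant under the square root, as in circle_residual.
    return jnp.sqrt(jnp.sum(diff**2) + 1e-8)


zero = jnp.zeros(2)
grad_plain = jax.grad(norm_plain)(zero)
grad_eps = jax.grad(norm_eps)(zero)

print("plain:", grad_plain)
print("eps:  ", grad_eps)
```

The safeguarded version keeps the Jacobian finite everywhere, at the cost of a negligible bias in the distance.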

In [None]:
# Create variable and costs.
circle_var = CircleVar(id=0)

# Initial guess: centroid of points, average distance as radius.
centroid = jnp.mean(points, axis=0)
avg_dist = jnp.mean(jnp.sqrt(jnp.sum((points - centroid) ** 2, axis=-1)))
initial_params = jnp.array([centroid[0], centroid[1], avg_dist])

# Build one batched cost: every point references the same circle variable (id 0).
costs_standard = [
    circle_residual(CircleVar(id=jnp.zeros(n_points, dtype=jnp.int32)), points)
]

initial_vals = jaxls.VarValues.make([circle_var.with_value(initial_params)])

# Build the problem.
problem_standard = jaxls.LeastSquaresProblem(costs_standard, [circle_var])

# Visualize the problem structure.
problem_standard.show()

In [None]:
# Analyze and solve.
problem_standard = problem_standard.analyze()
solution_standard = problem_standard.solve(initial_vals)

params_standard = solution_standard[circle_var]
print(
    f"Standard LS result: center=({params_standard[0]:.3f}, {params_standard[1]:.3f}), radius={params_standard[2]:.3f}"
)
print(f"True parameters:    center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")

In [None]:
import plotly.graph_objects as go
from IPython.display import HTML


def make_circle_trace(
    cx: float, cy: float, r: float, name: str, color: str, dash: str = "solid"
) -> go.Scatter:
    """Create a circle trace for plotting."""
    theta = np.linspace(0, 2 * np.pi, 100)
    x = cx + r * np.cos(theta)
    y = cy + r * np.sin(theta)
    return go.Scatter(
        x=x,
        y=y,
        mode="lines",
        name=name,
        line=dict(color=color, width=2, dash=dash),
    )


fig_standard = go.Figure()

# Inlier points.
fig_standard.add_trace(
    go.Scatter(
        x=all_x[~is_outlier],
        y=all_y[~is_outlier],
        mode="markers",
        marker=dict(size=8, color="#2196F3"),
        name="Inliers",
    )
)

# Outlier points.
fig_standard.add_trace(
    go.Scatter(
        x=all_x[is_outlier],
        y=all_y[is_outlier],
        mode="markers",
        marker=dict(size=10, color="#F44336", symbol="x"),
        name="Outliers",
    )
)

# True circle.
fig_standard.add_trace(
    make_circle_trace(true_cx, true_cy, true_r, "True circle", "#4CAF50", "dash")
)

# Standard LS result.
fig_standard.add_trace(
    make_circle_trace(
        float(params_standard[0]),
        float(params_standard[1]),
        float(params_standard[2]),
        "Standard LS",
        "#FF9800",
    )
)

fig_standard.update_xaxes(title_text="x", scaleanchor="y", scaleratio=1)
fig_standard.update_yaxes(title_text="y")
fig_standard.update_layout(
    height=450,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig_standard.to_html(full_html=False, include_plotlyjs="cdn"))

## Robust cost functions (M-estimators)

M-estimators replace the squared loss $r^2$ with a robust function $\rho(r)$ that grows
more slowly for large residuals. Common choices:

**Huber loss**: Quadratic for small residuals, linear for large:

$$
\rho_\text{Huber}(r) = \begin{cases} \frac{1}{2}r^2 & |r| \leq k \\ k(|r| - \frac{k}{2}) & |r| > k \end{cases}
$$

**Cauchy/Lorentzian**: Soft down-weighting of outliers:

$$
\rho_\text{Cauchy}(r) = \frac{c^2}{2} \log\left(1 + \left(\frac{r}{c}\right)^2\right)
$$

**Geman-McClure**: Even more aggressive outlier rejection:

$$
\rho_\text{GM}(r) = \frac{r^2/2}{1 + r^2}
$$
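
In IRLS, each loss enters through its weight function $w(r) = \rho'(r) / r$. As a worked step, differentiating the three losses above (with unit scale for Geman-McClure, matching its formula) gives:

$$
w_\text{Huber}(r) = \begin{cases} 1 & |r| \leq k \\ \frac{k}{|r|} & |r| > k \end{cases}
\qquad
w_\text{Cauchy}(r) = \frac{1}{1 + (r/c)^2}
\qquad
w_\text{GM}(r) = \frac{1}{(1 + r^2)^2}
$$

For example, for the Cauchy loss $\rho'(r) = \frac{c^2}{2} \cdot \frac{2r/c^2}{1 + (r/c)^2} = \frac{r}{1 + (r/c)^2}$, and dividing by $r$ gives the weight.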

In [None]:
def huber_weight(residual: jax.Array, k: float = 1.345) -> jax.Array:
    """Huber weight: 1 for small residuals, k/|r| for large.

    The default k=1.345 gives about 95% efficiency for unit-variance Gaussian residuals.
    """
    abs_r = jnp.abs(residual) + 1e-8
    return jnp.where(abs_r <= k, 1.0, k / abs_r)


def cauchy_weight(residual: jax.Array, c: float = 2.385) -> jax.Array:
    """Cauchy/Lorentzian weight: 1 / (1 + (r/c)^2).

    Provides soft down-weighting that never fully rejects points.
    """
    return 1.0 / (1.0 + (residual / c) ** 2)


def geman_mcclure_weight(residual: jax.Array, sigma: float = 1.0) -> jax.Array:
    """Geman-McClure weight: 1 / (1 + (r/sigma)^2)^2.

    Aggressive outlier rejection -- weight drops quickly for large residuals.
    """
    r_scaled = residual / sigma
    return 1.0 / (1.0 + r_scaled**2) ** 2

In [None]:
import plotly.graph_objects as go
from IPython.display import HTML

# Visualize weight functions.
r_vals = jnp.linspace(-5, 5, 200)

fig_weights = go.Figure()
fig_weights.add_trace(
    go.Scatter(
        x=r_vals,
        y=jnp.ones_like(r_vals),
        mode="lines",
        name="Standard LS (weight=1)",
        line=dict(color="gray", dash="dash"),
    )
)
fig_weights.add_trace(
    go.Scatter(
        x=r_vals,
        y=huber_weight(r_vals),
        mode="lines",
        name="Huber (k=1.345)",
        line=dict(color="#2196F3"),
    )
)
fig_weights.add_trace(
    go.Scatter(
        x=r_vals,
        y=cauchy_weight(r_vals),
        mode="lines",
        name="Cauchy (c=2.385)",
        line=dict(color="#4CAF50"),
    )
)
fig_weights.add_trace(
    go.Scatter(
        x=r_vals,
        y=geman_mcclure_weight(r_vals),
        mode="lines",
        name="Geman-McClure",
        line=dict(color="#FF9800"),
    )
)

fig_weights.update_layout(
    title="Robust Weight Functions",
    xaxis_title="Residual",
    yaxis_title="Weight",
    height=350,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)
HTML(fig_weights.to_html(full_html=False, include_plotlyjs="cdn"))

## IRLS with jaxls

[Iteratively reweighted least squares](https://en.wikipedia.org/wiki/Iteratively_reweighted_least_squares) (IRLS) is an algorithm for robust estimation. The key to implementing IRLS in jaxls is using `jax.lax.stop_gradient` on the weights. This tells JAX to treat the weights as constants when computing Jacobians, so the solver sees a weighted least squares problem at each iteration. As the solution changes across iterations, the weights automatically update, implementing IRLS without an explicit outer loop.
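
For comparison, a classical IRLS implementation with an explicit outer loop might look like the following NumPy sketch. It is not jaxls code: it estimates a robust 1D mean on made-up data (a hypothetical toy problem, using the Cauchy weight with `c=0.5`), chosen so that the weighted least squares step has a closed form.

```python
import numpy as np


def cauchy_weight(residual: np.ndarray, c: float = 0.5) -> np.ndarray:
    """Cauchy weight: 1 / (1 + (r/c)^2)."""
    return 1.0 / (1.0 + (residual / c) ** 2)


rng = np.random.default_rng(0)
# Hypothetical 1D data: 40 inliers near 2.0 plus 10 gross outliers near 10.0.
data = np.concatenate([rng.normal(2.0, 0.1, 40), rng.normal(10.0, 0.5, 10)])

# Start from the plain (non-robust) mean.
mu = data.mean()
for _ in range(20):
    # Step 1: freeze the weights at the current estimate.
    weights = cauchy_weight(data - mu)
    # Step 2: solve the weighted least squares problem (a weighted mean here).
    mu = np.sum(weights * data) / np.sum(weights)

print(f"Plain mean: {data.mean():.3f}")
print(f"IRLS mean:  {mu:.3f}")
```

The two steps inside the loop are what `stop_gradient` folds into a single solver iteration: weights are evaluated at the current estimate, then held fixed while the update is computed.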

The cost function computes:
1. Residual $r_i = \|p_i - p_{\text{circle}}\|$ for each point $p_i$, where $p_{\text{circle}}$ is the closest point on the circle
2. Weight $w_i = w(r_i) = \rho'(r_i) / r_i$ from the M-estimator, wrapped in `stop_gradient`
3. Weighted residual $\sqrt{w_i} \cdot r_i$

Since jaxls squares residuals, the objective becomes $\sum_i w_i r_i^2$, exactly the IRLS formulation.
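
The effect of `stop_gradient` on the Jacobian can be checked on a scalar residual. In this sketch (the function names are illustrative and the Cauchy constant `c=0.5` is just an example value), the derivative of the weighted residual equals $\sqrt{w}$ when the weight is frozen, and something quite different when JAX differentiates through the weight.

```python
import jax
import jax.numpy as jnp


def cauchy_weight(residual: jax.Array, c: float = 0.5) -> jax.Array:
    return 1.0 / (1.0 + (residual / c) ** 2)


def weighted_residual(r: jax.Array) -> jax.Array:
    # The weight is treated as a constant during differentiation.
    w = jax.lax.stop_gradient(cauchy_weight(r))
    return jnp.sqrt(w) * r


def weighted_residual_no_stop(r: jax.Array) -> jax.Array:
    # The weight is differentiated through, which changes the Jacobian.
    return jnp.sqrt(cauchy_weight(r)) * r


r0 = jnp.array(2.0)
jac_stop = jax.grad(weighted_residual)(r0)
jac_full = jax.grad(weighted_residual_no_stop)(r0)
sqrt_w = jnp.sqrt(cauchy_weight(r0))

print("sqrt(w):          ", sqrt_w)
print("d/dr with stop:   ", jac_stop)
print("d/dr without stop:", jac_full)
```

With the weight frozen, the solver linearizes $\sqrt{w_i} \, r_i$ exactly as it would in a weighted least squares problem with fixed weights.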

In [None]:
def make_robust_circle_cost(weight_fn: callable):
    """Create a robust circle cost with IRLS weights via stop_gradient.

    Args:
        weight_fn: Function mapping residual magnitude to weight (e.g., cauchy_weight).

    Returns:
        A jaxls cost factory that applies robust weighting.
    """

    @jaxls.Cost.factory
    def robust_circle_residual(
        vals: jaxls.VarValues,
        circle: CircleVar,
        point: jax.Array,
    ) -> jax.Array:
        """2D residual with automatic IRLS weighting."""
        params = vals[circle]
        cx, cy, r = params[0], params[1], params[2]
        diff = point - jnp.array([cx, cy])
        dist = jnp.sqrt(jnp.sum(diff**2) + 1e-8)
        direction = diff / dist
        residual = dist - r

        # Compute weight from current residual, but stop gradients.
        # This makes the optimizer treat weights as constants, implementing IRLS.
        weight = jax.lax.stop_gradient(weight_fn(residual))
        sqrt_weight = jnp.sqrt(weight + 1e-8)

        return sqrt_weight * residual * direction

    return robust_circle_residual

In [None]:
def compute_weights(
    params: jax.Array, points: jax.Array, weight_fn: callable
) -> jax.Array:
    """Compute IRLS weights for visualization."""
    cx, cy, r = params[0], params[1], params[2]
    dists = jnp.sqrt((points[:, 0] - cx) ** 2 + (points[:, 1] - cy) ** 2 + 1e-8)
    residuals = dists - r
    return weight_fn(residuals)

## Robust fitting with Cauchy weights

With the `stop_gradient` approach, robust fitting is just a single `solve()` call:

In [None]:
# Create robust cost with Cauchy weights.
cauchy_weight_fn = lambda r: cauchy_weight(r, c=0.5)
robust_cost = make_robust_circle_cost(cauchy_weight_fn)

# Build and solve (just one call - IRLS happens automatically via solver iterations).
costs_irls = [robust_cost(CircleVar(id=jnp.zeros(n_points, dtype=jnp.int32)), points)]

initial_vals = jaxls.VarValues.make([circle_var.with_value(initial_params)])
problem_irls = jaxls.LeastSquaresProblem(costs_irls, [circle_var]).analyze()
solution_irls = problem_irls.solve(initial_vals)

params_irls = solution_irls[circle_var]
weights_irls = compute_weights(params_irls, points, cauchy_weight_fn)

print(
    f"IRLS result:     center=({params_irls[0]:.3f}, {params_irls[1]:.3f}), radius={params_irls[2]:.3f}"
)
print(f"True parameters: center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")
print(
    f"\nCenter error: {jnp.sqrt((params_irls[0] - true_cx) ** 2 + (params_irls[1] - true_cy) ** 2):.4f}"
)
print(f"Radius error: {jnp.abs(params_irls[2] - true_r):.4f}")

## Visualization

Compare standard least squares (pulled by outliers) vs IRLS (robust to outliers):

In [None]:
from plotly.subplots import make_subplots

fig_compare = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Standard Least Squares", "IRLS (Cauchy)"),
)

# Common elements for both plots.
for col in [1, 2]:
    # Inlier points.
    fig_compare.add_trace(
        go.Scatter(
            x=all_x[~is_outlier],
            y=all_y[~is_outlier],
            mode="markers",
            marker=dict(size=8, color="#2196F3"),
            name="Inliers",
            showlegend=(col == 1),
        ),
        row=1,
        col=col,
    )

    # Outlier points.
    fig_compare.add_trace(
        go.Scatter(
            x=all_x[is_outlier],
            y=all_y[is_outlier],
            mode="markers",
            marker=dict(size=10, color="#F44336", symbol="x"),
            name="Outliers",
            showlegend=(col == 1),
        ),
        row=1,
        col=col,
    )

    # True circle.
    true_circle = make_circle_trace(
        true_cx, true_cy, true_r, "True circle", "#4CAF50", "dash"
    )
    true_circle.showlegend = col == 1
    fig_compare.add_trace(true_circle, row=1, col=col)

# Standard LS result.
fig_compare.add_trace(
    make_circle_trace(
        float(params_standard[0]),
        float(params_standard[1]),
        float(params_standard[2]),
        "Standard LS",
        "#FF9800",
    ),
    row=1,
    col=1,
)

# IRLS result.
fig_compare.add_trace(
    make_circle_trace(
        float(params_irls[0]),
        float(params_irls[1]),
        float(params_irls[2]),
        "IRLS",
        "#9C27B0",
    ),
    row=1,
    col=2,
)

fig_compare.update_xaxes(title_text="x", scaleanchor="y", scaleratio=1)
fig_compare.update_yaxes(title_text="y")
fig_compare.update_layout(
    height=450,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
)
HTML(fig_compare.to_html(full_html=False, include_plotlyjs="cdn"))

## Weight visualization

IRLS identifies outliers by assigning them low weights. Let's visualize the final weights:

In [None]:
fig_weights_viz = go.Figure()

# Color points by weight (blue=high weight/inlier, red=low weight/outlier).
fig_weights_viz.add_trace(
    go.Scatter(
        x=all_x,
        y=all_y,
        mode="markers",
        marker=dict(
            size=12,
            color=weights_irls,
            colorscale=[[0, "#F44336"], [1, "#2196F3"]],
            colorbar=dict(title="Weight", thickness=15),
            cmin=0,
            cmax=1,
        ),
        text=[f"Weight: {w:.3f}" for w in weights_irls],
        hovertemplate="(%{x:.2f}, %{y:.2f})<br>%{text}<extra></extra>",
        name="Points",
    )
)

# Fitted circle.
fig_weights_viz.add_trace(
    make_circle_trace(
        float(params_irls[0]),
        float(params_irls[1]),
        float(params_irls[2]),
        "IRLS fit",
        "#9C27B0",
    ),
)

# True circle.
fig_weights_viz.add_trace(
    make_circle_trace(true_cx, true_cy, true_r, "True circle", "#4CAF50", "dash"),
)

fig_weights_viz.update_xaxes(title_text="x", scaleanchor="y", scaleratio=1)
fig_weights_viz.update_yaxes(title_text="y")
fig_weights_viz.update_layout(
    title="Points Colored by IRLS Weight",
    height=450,
    margin=dict(t=60, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig_weights_viz.to_html(full_html=False, include_plotlyjs="cdn"))
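
The weights can also be thresholded to flag likely outliers. The sketch below uses made-up residuals rather than the fit above, and the 0.5 cutoff is an illustrative assumption, not a jaxls setting. For the Cauchy weight, a cutoff of 0.5 corresponds to $|r| = c$.

```python
import numpy as np

# Hypothetical signed residuals (distance to the fitted circle minus its radius).
residuals = np.array([0.05, -0.10, 0.08, 3.20, -2.50])

# Cauchy weights with c=0.5, the tuning constant used for the fit above.
c = 0.5
weights = 1.0 / (1.0 + (residuals / c) ** 2)

# Illustrative cutoff: points with weight below it are flagged as outliers.
threshold = 0.5
flagged = weights < threshold

for r, w, f in zip(residuals, weights, flagged):
    print(f"residual={r:+.2f}  weight={w:.3f}  flagged={f}")
```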

## Comparing M-estimators

Different weight functions give different levels of robustness:

In [None]:
# Fit with different M-estimators.
results = {"Standard LS": params_standard}

for name, weight_fn in [
    ("Huber", lambda r: huber_weight(r, k=0.5)),
    ("Cauchy", lambda r: cauchy_weight(r, c=0.5)),
    ("Geman-McClure", lambda r: geman_mcclure_weight(r, sigma=0.5)),
]:
    robust_cost = make_robust_circle_cost(weight_fn)
    costs = [robust_cost(CircleVar(id=jnp.zeros(n_points, dtype=jnp.int32)), points)]
    problem = jaxls.LeastSquaresProblem(costs, [circle_var]).analyze()
    solution = problem.solve(initial_vals, verbose=False)
    results[name] = solution[circle_var]

print(
    f"{'Method':<15} {'cx':>8} {'cy':>8} {'r':>8} {'Center err':>12} {'Radius err':>12}"
)
print("-" * 70)
print(
    f"{'True':<15} {true_cx:>8.3f} {true_cy:>8.3f} {true_r:>8.3f} {'-':>12} {'-':>12}"
)
for name, params in results.items():
    center_err = float(
        jnp.sqrt((params[0] - true_cx) ** 2 + (params[1] - true_cy) ** 2)
    )
    radius_err = float(jnp.abs(params[2] - true_r))
    print(
        f"{name:<15} {float(params[0]):>8.3f} {float(params[1]):>8.3f} {float(params[2]):>8.3f} {center_err:>12.4f} {radius_err:>12.4f}"
    )
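
One way to see why the estimators behave differently is the influence $\psi(r) = w(r) \cdot r$, which measures how strongly a single residual pulls on the fit. The NumPy sketch below re-implements the three weight functions with the tuning constant 0.5 used in the comparison and evaluates the influence at a few example residuals (the residual values are arbitrary).

```python
import numpy as np


def huber_weight(r, k=0.5):
    abs_r = np.abs(r) + 1e-8
    return np.where(abs_r <= k, 1.0, k / abs_r)


def cauchy_weight(r, c=0.5):
    return 1.0 / (1.0 + (r / c) ** 2)


def geman_mcclure_weight(r, sigma=0.5):
    return 1.0 / (1.0 + (r / sigma) ** 2) ** 2


# Influence psi(r) = w(r) * r at a few arbitrary residual magnitudes.
r = np.array([0.1, 0.5, 2.0, 5.0])
influence = {
    "Huber": huber_weight(r) * r,
    "Cauchy": cauchy_weight(r) * r,
    "Geman-McClure": geman_mcclure_weight(r) * r,
}

for name, psi in influence.items():
    print(f"{name:<15}", np.round(psi, 4))
```

Huber's influence saturates at `k`, so distant outliers keep a bounded but nonzero pull. The Cauchy and Geman-McClure influences fall back toward zero for large residuals, with Geman-McClure decaying faster.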

In [None]:
# Visualize all methods.
colors = {
    "Standard LS": "#FF9800",
    "Huber": "#2196F3",
    "Cauchy": "#9C27B0",
    "Geman-McClure": "#00BCD4",
}

fig_all = go.Figure()

# Points.
fig_all.add_trace(
    go.Scatter(
        x=all_x[~is_outlier],
        y=all_y[~is_outlier],
        mode="markers",
        marker=dict(size=8, color="#607D8B"),
        name="Inliers",
    )
)
fig_all.add_trace(
    go.Scatter(
        x=all_x[is_outlier],
        y=all_y[is_outlier],
        mode="markers",
        marker=dict(size=10, color="#F44336", symbol="x"),
        name="Outliers",
    )
)

# True circle.
fig_all.add_trace(
    make_circle_trace(true_cx, true_cy, true_r, "True", "#4CAF50", "dash"),
)

# Fitted circles.
for name, params in results.items():
    fig_all.add_trace(
        make_circle_trace(
            float(params[0]), float(params[1]), float(params[2]), name, colors[name]
        ),
    )

fig_all.update_xaxes(title_text="x", scaleanchor="y", scaleratio=1)
fig_all.update_yaxes(title_text="y")
fig_all.update_layout(
    title="Comparison of M-Estimators for Circle Fitting",
    height=500,
    margin=dict(t=60, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)
HTML(fig_all.to_html(full_html=False, include_plotlyjs="cdn"))

## Summary

Robust estimation in jaxls using `jax.lax.stop_gradient`:

1. Standard least squares gives outliers disproportionate influence due to the squared loss.

2. M-estimators (Huber, Cauchy, Geman-McClure) define weight functions that down-weight large residuals.

3. Using `stop_gradient` on weights implements IRLS automatically within the solver's iterations -- no explicit outer loop needed.

4. Huber provides convex robustness; Cauchy gives smooth down-weighting; Geman-McClure aggressively rejects outliers.

Implementation pattern:
```python
weight = jax.lax.stop_gradient(weight_fn(residual))
return jnp.sqrt(weight) * residual
```

This makes the optimizer treat weights as constants while still updating them as the solution evolves. The solver's iterations naturally perform IRLS.

For more details, see {class}`jaxls.Var`, {class}`jaxls.Cost`, and {class}`jaxls.LeastSquaresProblem`.