# Basics

An introduction to nonlinear least squares optimization with jaxls.

Features used:
- {class}`~jaxls.Var` subclassing for custom variable types
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` for defining cost functions
- {class}`~jaxls.LeastSquaresProblem` for building and solving optimization problems
- {class}`~jaxls.VarValues` for accessing solution values

In [None]:
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [None]:
import jax
import jax.numpy as jnp
import jaxls

## The problem: circle fitting

Given a set of noisy 2D points that lie approximately on a circle, we want to find the circle parameters (center and radius) that best fit the data.

This is a classic nonlinear least squares problem: we define a residual vector that measures the error for each data point, then minimize the sum of squared residual norms. In jaxls, a "cost" is a term in this objective. Each cost computes a residual vector $r_i(x)$, where $x$ denotes the variables being optimized, and the solver minimizes $\sum_i \|r_i(x)\|^2$.
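
As a small worked example of the objective, the plain-Python function below sums the squared Euclidean norms of a list of residual vectors. `sum_squared_norms` is a hypothetical helper, not part of jaxls, and the residuals are made-up numbers chosen for easy arithmetic. Note that the formula itself does not require the terms to share a dimension.

```python
def sum_squared_norms(residuals):
    """Objective value: the sum over terms of each residual's squared Euclidean norm."""
    return sum(sum(c * c for c in r) for r in residuals)


# Three hypothetical residual vectors of different lengths.
example = [(3.0, 4.0), (1.0,), (0.0, -2.0, 0.0)]
print(sum_squared_norms(example))  # 9 + 16 + 1 + 4 = 30.0
```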

In [None]:
# Generate noisy points on a circle.
true_cx, true_cy, true_r = 2.0, 1.5, 3.0  # True circle parameters.
num_points = 30
noise_std = 0.15

# Sample angles uniformly around the circle.
key = jax.random.PRNGKey(42)
angles = jnp.linspace(0, 2 * jnp.pi, num_points, endpoint=False)

# Generate points with Gaussian noise.
key, subkey = jax.random.split(key)
noise = jax.random.normal(subkey, shape=(num_points, 2)) * noise_std

points = (
    jnp.stack(
        [
            true_cx + true_r * jnp.cos(angles),
            true_cy + true_r * jnp.sin(angles),
        ],
        axis=-1,
    )
    + noise
)

print(f"Generated {num_points} noisy points around circle")
print(f"True parameters: center=({true_cx}, {true_cy}), radius={true_r}")

## Defining variables

Variables represent the unknowns we want to optimize. In jaxls, we define custom variable types by subclassing {class}`~jaxls.Var`.

Each variable type specifies:
- The data type it holds (any pytree, e.g., `jax.Array`, dataclasses, nested structures)
- A `default_factory` that creates an initial value
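
The `default_factory=` keyword in a variable's class statement is standard Python. Keyword arguments in a `class` header are forwarded to the base class's `__init_subclass__` hook. The sketch below illustrates that mechanism with a hypothetical `ToyVar` base class. It is for explanation only and does not reproduce jaxls's implementation of {class}`~jaxls.Var`.

```python
from typing import Callable, ClassVar, Generic, Optional, TypeVar

T = TypeVar("T")


class ToyVar(Generic[T]):
    """Hypothetical stand-in for a variable base class (not jaxls code)."""

    default_factory: ClassVar[Callable[[], object]]

    def __init_subclass__(cls, default_factory: Optional[Callable[[], T]] = None, **kwargs):
        # Class-header keywords such as `default_factory=` arrive here.
        super().__init_subclass__(**kwargs)
        if default_factory is not None:
            # staticmethod so the factory can be called without binding an instance.
            cls.default_factory = staticmethod(default_factory)

    def __init__(self, id: int) -> None:
        self.id = id


class ToyCircleVar(ToyVar[list], default_factory=lambda: [0.0, 0.0, 1.0]):
    """Toy circle parameters: [cx, cy, radius]."""


v = ToyCircleVar(id=0)
print(v.id, ToyCircleVar.default_factory())
```

In this sketch, the factory returns a fresh list on every call, so two callers never share one mutable default.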

jaxls also supports non-Euclidean variables for optimization on manifolds like rotations. See [Non-Euclidean variables](advanced/non_euclidean.ipynb) for details.
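
To see why manifold variables need special treatment, consider a 2D rotation stored as a unit complex number $\cos\theta + i\sin\theta$. The stdlib sketch below compares two updates: adding a step directly to the stored components, and composing the rotation with $\exp(i\,\delta)$. The second approach is one common way to apply an update on the manifold. This is an illustration only and does not use jaxls's Lie group types.

```python
import cmath

# A 2D rotation by 0.3 rad, stored as the unit complex number cos(0.3) + i*sin(0.3).
rot = cmath.exp(0.3j)

# An optimizer step expressed as a rotation increment (tangent-space delta), in radians.
delta = 0.5

# Naive Euclidean update: add a step vector to the stored components.
naive = rot + complex(delta, delta)

# Manifold update: compose with exp(i * delta), which keeps unit magnitude.
retracted = rot * cmath.exp(1j * delta)

print(abs(naive))              # no longer 1, so no longer a valid rotation
print(abs(retracted))          # 1 up to rounding
print(cmath.phase(retracted))  # 0.3 + 0.5 up to rounding
```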

In [None]:
class CircleVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array([0.0, 0.0, 1.0]),
):
    """Circle parameters: [cx, cy, radius]."""


# Create a variable instance with a unique ID.
circle_var = CircleVar(id=0)

print(f"Created variable: {circle_var}")
print(f"Default value: {CircleVar.default_factory()}")

## Defining costs

Cost functions define what we're optimizing. Use the {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` decorator to create cost factories.

A cost function:
- Takes a {class}`~jaxls.VarValues` object as its first argument (for looking up variable values)
- Takes additional arguments (variables and/or static data)
- Returns a residual vector to minimize

The solver minimizes the sum of squared residual norms across all costs.

In [None]:
@jaxls.Cost.factory
def circle_residual(
    vals: jaxls.VarValues,
    circle: CircleVar,
    point: jax.Array,
) -> jax.Array:
    """Residual for a point's distance to the circle.

    Args:
        vals: Container for looking up current variable values.
        circle: The circle variable to fit.
        point: A 2D point that should lie on the circle.

    Returns:
        2D residual vector pointing from the closest point on the circle to the observed point.
    """
    params = vals[circle]  # Look up the current circle parameters.
    cx, cy, r = params[0], params[1], params[2]

    # Vector from center to point.
    diff = point - jnp.array([cx, cy])
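    # The small epsilon avoids division by zero (and an infinite sqrt gradient)
    # when the point coincides with the circle center.
    dist_to_center = jnp.sqrt(jnp.sum(diff**2) + 1e-6)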
    direction = diff / dist_to_center

    # 2D residual: error vector from the closest point on the circle to the observed point.
    return (dist_to_center - r) * direction
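
Let $d$ be the point's distance to the center. The 2D residual's norm is $|d - r|$ up to the epsilon, so minimizing its squared norm matches the familiar geometric circle fit with scalar residuals $d - r$. The dependency-free sketch below checks this numerically using plain-Python mirrors of the residual. The names `vector_residual` and `scalar_residual` are hypothetical. The sketch also shows that the epsilon keeps the division finite when a point sits exactly on the center.

```python
import math


def vector_residual(cx, cy, r, px, py, eps=1e-6):
    """Plain-Python mirror of the 2D residual: (dist - r) * unit direction."""
    dx, dy = px - cx, py - cy
    dist = math.sqrt(dx * dx + dy * dy + eps)
    return ((dist - r) * dx / dist, (dist - r) * dy / dist)


def scalar_residual(cx, cy, r, px, py):
    """Signed distance from the point to the circle."""
    return math.hypot(px - cx, py - cy) - r


# A point 1 unit outside the circle: the vector norm matches |scalar residual|.
rx, ry = vector_residual(2.0, 1.5, 3.0, 6.0, 1.5)
print(math.hypot(rx, ry), abs(scalar_residual(2.0, 1.5, 3.0, 6.0, 1.5)))

# A point exactly at the center: without eps, dist would be 0 and the division would fail.
at_center = vector_residual(2.0, 1.5, 3.0, 2.0, 1.5)
print(at_center)
```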

The decorator turns `circle_residual` into a factory. Calling it with a variable instance and static data returns a cost object instead of evaluating the residual immediately. We create one cost per observed point:

In [None]:
# Create a cost for each observed point.
costs = [circle_residual(circle_var, point) for point in points]

print(f"Created {len(costs)} cost objects")
print(f"Example cost: {costs[0]}")

For large problems, batched construction (passing arrays of IDs and data) is more efficient. See {doc}`tips_and_gotchas` for details.
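
One way to picture factory calls such as `circle_residual(circle_var, point)` is as deferred evaluation. The call binds the variable and static data, and the residual is computed later, when the solver supplies current variable values as the first argument. The plain-Python sketch below models that idea with a hypothetical `toy_cost_factory` decorator and `ToyCost` class. A string stands in for a variable, and a dict stands in for {class}`~jaxls.VarValues`. This is a mental model, not jaxls's implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class ToyCost:
    """Hypothetical deferred cost: a residual function plus its bound arguments."""

    fn: Callable[..., Any]
    args: Tuple[Any, ...]

    def residual(self, values: Dict[str, Any]) -> Any:
        # The value lookup is injected as the first argument at evaluation time.
        return self.fn(values, *self.args)


def toy_cost_factory(fn: Callable[..., Any]) -> Callable[..., ToyCost]:
    """Hypothetical decorator: calling the result binds arguments instead of evaluating."""

    def make(*args: Any) -> ToyCost:
        return ToyCost(fn, args)

    return make


@toy_cost_factory
def offset_residual(values, var_key, target):
    # Residual of a scalar variable against a fixed target value.
    return [values[var_key] - target]


costs = [offset_residual("x", t) for t in (1.0, 2.0, 4.0)]
values = {"x": 2.0}
print([c.residual(values) for c in costs])  # [[1.0], [0.0], [-2.0]]
total = sum(e * e for c in costs for e in c.residual(values))
print(total)  # 5.0
```

Summing the squared residual entries over all costs gives the objective value at the given variable values.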

## Building problems

A {class}`~jaxls.LeastSquaresProblem` bundles costs and variables together.

We can call {meth}`.show() <jaxls.LeastSquaresProblem.show>` to visualize the problem structure. In notebooks this displays inline; outside notebooks it opens in a browser.

In [None]:
# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [circle_var])

# Visualize the problem structure.
problem.show()

In [None]:
# Analyze the problem structure for efficient solving.
problem = problem.analyze()

## Solving

Call `.solve()` to run the Levenberg-Marquardt optimizer. The solver iteratively adjusts the variables to minimize the sum of squared residuals.

In [None]:
# Solve the problem (uses default initial values from the variable's default_factory).
solution = problem.solve()

# Access the solution using the variable as a key.
result = solution[circle_var]
est_cx, est_cy, est_r = result[0], result[1], result[2]

print(
    f"\nEstimated: center=({float(est_cx):.3f}, {float(est_cy):.3f}), radius={float(est_r):.3f}"
)
print(f"True:      center=({true_cx:.3f}, {true_cy:.3f}), radius={true_r:.3f}")
print(
    f"\nCenter error: {float(jnp.sqrt((est_cx - true_cx) ** 2 + (est_cy - true_cy) ** 2)):.4f}"
)
print(f"Radius error: {float(abs(est_r - true_r)):.4f}")
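
The sketch below runs a Levenberg-Marquardt iteration like the one behind `.solve()` on noise-free points, using only the standard library. It makes several assumptions:

- Scalar residuals $d_i - r$, which are equivalent, up to the epsilon, to the 2D residuals used above.
- A hand-derived Jacobian.
- Levenberg's damping $\lambda I$ added to $J^\top J$. Marquardt's variant scales the damping by $\operatorname{diag}(J^\top J)$ instead.
- A simple schedule that shrinks $\lambda$ after an accepted step and grows it after a rejected one.

All function names are hypothetical. jaxls's actual solver likely differs in its linear solver, damping strategy, and stopping criteria.

```python
import math


def residuals(params, pts):
    cx, cy, r = params
    return [math.hypot(x - cx, y - cy) - r for x, y in pts]


def jacobian(params, pts):
    # Partial derivatives of (d_i - r) with respect to (cx, cy, r).
    cx, cy, _ = params
    rows = []
    for x, y in pts:
        d = math.hypot(x - cx, y - cy)
        rows.append([-(x - cx) / d, -(y - cy) / d, -1.0])
    return rows


def solve3(a, b):
    """Solve a 3x3 linear system by Gaussian elimination with partial pivoting."""
    m = [row[:] + [bi] for row, bi in zip(a, b)]
    for col in range(3):
        piv = max(range(col, 3), key=lambda i: abs(m[i][col]))
        m[col], m[piv] = m[piv], m[col]
        for i in range(col + 1, 3):
            f = m[i][col] / m[col][col]
            for j in range(col, 4):
                m[i][j] -= f * m[col][j]
    x = [0.0] * 3
    for i in reversed(range(3)):
        x[i] = (m[i][3] - sum(m[i][j] * x[j] for j in range(i + 1, 3))) / m[i][i]
    return x


def cost(params, pts):
    return sum(e * e for e in residuals(params, pts))


def levenberg_marquardt(params, pts, lam=1e-3, iters=100, tol=1e-12):
    params = list(params)
    for _ in range(iters):
        r = residuals(params, pts)
        J = jacobian(params, pts)
        # Damped normal equations: (J^T J + lam * I) delta = -J^T r.
        JtJ = [[sum(row[i] * row[j] for row in J) for j in range(3)] for i in range(3)]
        Jtr = [sum(row[i] * e for row, e in zip(J, r)) for i in range(3)]
        A = [[JtJ[i][j] + (lam if i == j else 0.0) for j in range(3)] for i in range(3)]
        delta = solve3(A, [-g for g in Jtr])
        candidate = [p + d for p, d in zip(params, delta)]
        if cost(candidate, pts) < cost(params, pts):
            params, lam = candidate, lam * 0.1  # accept: lean toward Gauss-Newton
        else:
            lam *= 10.0  # reject: lean toward a short gradient-descent step
        if max(abs(d) for d in delta) < tol:
            break
    return params


# Noise-free points on the notebook's true circle, starting from CircleVar's default.
true_params = (2.0, 1.5, 3.0)
pts = [
    (true_params[0] + true_params[2] * math.cos(2 * math.pi * k / 12),
     true_params[1] + true_params[2] * math.sin(2 * math.pi * k / 12))
    for k in range(12)
]
est = levenberg_marquardt([0.0, 0.0, 1.0], pts)
print([round(v, 6) for v in est])
```

A rejected step only raises $\lambda$, which pulls the next step toward a short gradient-descent move. Accepted steps lower $\lambda$, recovering fast Gauss-Newton convergence near the solution. The sketch does not guard against a data point landing exactly on the current center estimate, where the Jacobian is undefined.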

## Visualization

In [None]:
import plotly.graph_objects as go
from IPython.display import HTML


def make_circle_trace(
    cx: float, cy: float, r: float, name: str, color: str, dash: str = "solid"
) -> go.Scatter:
    """Create a Plotly trace for a circle.

    Args:
        cx: Circle center x coordinate.
        cy: Circle center y coordinate.
        r: Circle radius.
        name: Legend name.
        color: Line color.
        dash: Line dash style.

    Returns:
        Plotly Scatter trace for the circle.
    """
    theta = jnp.linspace(0, 2 * jnp.pi, 100)
    x = cx + r * jnp.cos(theta)
    y = cy + r * jnp.sin(theta)
    return go.Scatter(
        x=x,
        y=y,
        mode="lines",
        name=name,
        line=dict(color=color, width=2, dash=dash),
    )


fig = go.Figure()

# Noisy data points.
fig.add_trace(
    go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="markers",
        name="Noisy points",
        marker=dict(size=8, color="steelblue"),
        hovertemplate="(%{x:.2f}, %{y:.2f})",
    )
)

# True circle.
fig.add_trace(
    make_circle_trace(true_cx, true_cy, true_r, "True circle", "green", "dash")
)

# Fitted circle.
fig.add_trace(
    make_circle_trace(
        float(est_cx), float(est_cy), float(est_r), "Fitted circle", "crimson"
    )
)

# Center points.
fig.add_trace(
    go.Scatter(
        x=[true_cx, float(est_cx)],
        y=[true_cy, float(est_cy)],
        mode="markers",
        marker=dict(size=10, symbol="x", color=["green", "crimson"]),
        name="Centers",
        showlegend=False,
    )
)

fig.update_layout(
    title="Circle Fitting with jaxls",
    xaxis_title="x",
    yaxis_title="y",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=450,
    margin=dict(t=40, b=40, l=40, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

## Summary

The key steps for solving nonlinear least squares problems with jaxls:

1. Define variables by subclassing {class}`~jaxls.Var` with a `default_factory`
2. Define costs using {func}`@jaxls.Cost.factory <jaxls.Cost.factory>`, returning residual vectors
3. Build the problem with {class}`~jaxls.LeastSquaresProblem` and call `.analyze()`
4. Solve with `.solve()` and access results via the {class}`~jaxls.VarValues` object

For more, see:
- {doc}`tips_and_gotchas`: Batched construction, residual dimensions, solver selection
- {doc}`advanced/constraints`: Equality and inequality constraints
- {doc}`advanced/non_euclidean`: Lie group variables for rotations and poses
- {doc}`advanced/robust_costs`: Handling outliers with M-estimators