Spring equilibrium

In this notebook, we solve a spring network equilibrium problem: finding the resting shape of a mesh of masses connected by springs under gravity.

Features used:

  • Var subclassing for custom variable types

  • @jaxls.Cost.factory for defining costs

  • Equality constraints (constraint_eq_zero) for anchoring points

  • Augmented Lagrangian solver for constrained optimization

We show two approaches:

  1. Naive construction: Build variables and costs one-by-one (simple but slower)

  2. Batched construction: Use array operations for efficiency (recommended for large problems)


import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls

Variables and costs

Variables represent the unknowns we want to solve for. We define a custom 2D point variable by subclassing jaxls.Var:

class Point2Var(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(2)):
    """A 2D point variable."""

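Cost functions define what we're optimizing. Decorating a function with @jaxls.Cost.factory turns it into a cost factory: calling the decorated function with variables and arguments (but no VarValues) returns a jaxls.Cost object, as the construction code later in this notebook does.

A cost function takes a VarValues object (for looking up variable values) plus any arguments, and returns a residual vector. By default the solver minimizes the sum of squared residuals; with kind="constraint_eq_zero", the residual is instead treated as a hard equality constraint that must equal zero.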

@jaxls.Cost.factory
def spring_cost(
    vals: jaxls.VarValues,
    var_a: Point2Var,
    var_b: Point2Var,
    rest_length: float,
) -> jax.Array:
    """Penalize deviation from rest length.

    Returns a 2D residual for better Jacobian conditioning.
    """
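    diff = vals[var_a] - vals[var_b]
    # The small epsilon keeps the sqrt and its gradient finite if the points coincide.
    length = jnp.sqrt(jnp.sum(diff**2) + 1e-6)
    direction = diff / length
    # Stretch (or compression) along the spring axis; norm is about |length - rest_length|.
    return (length - rest_length) * direction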


@jaxls.Cost.factory(kind="constraint_eq_zero")
def anchor_constraint(
    vals: jaxls.VarValues,
    var: Point2Var,
    target: jax.Array,
) -> jax.Array:
    """Pin a point to a target position (hard constraint)."""
    return vals[var] - target


@jaxls.Cost.factory
def gravity_cost(
    vals: jaxls.VarValues,
    var: Point2Var,
) -> jax.Array:
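    """Pull points downward (toward y = -10)."""
    # Residual is proportional to the height above y = -10; 0.15 sets the pull strength.
    return (vals[var][1] + 10.0) * 0.15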
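To check what these residuals compute, here is a NumPy-only mirror of the three cost functions. The names spring_residual, gravity_residual and anchor_residual are hypothetical, and jnp is swapped for np so the snippet runs without JAX or jaxls. It shows that the spring residual's norm is approximately the stretch |length - rest_length|, and how large a single gravity term is for a point at y = 0.

```python
import numpy as np

# Hypothetical NumPy mirrors of spring_cost, gravity_cost and anchor_constraint,
# used only to check the residual math (no jaxls involved).


def spring_residual(p_a: np.ndarray, p_b: np.ndarray, rest_length: float) -> np.ndarray:
    diff = p_a - p_b
    length = np.sqrt(np.sum(diff**2) + 1e-6)  # same epsilon as spring_cost
    return (length - rest_length) * diff / length


def gravity_residual(p: np.ndarray) -> float:
    return float((p[1] + 10.0) * 0.15)


def anchor_residual(p: np.ndarray, target: np.ndarray) -> np.ndarray:
    return p - target


# Spring of rest length 1.0 stretched to length 1.5 along x: residual norm is ~0.5.
stretched = spring_residual(np.array([1.5, 0.0]), np.array([0.0, 0.0]), 1.0)
print(stretched, np.linalg.norm(stretched))

# Spring at its rest length: residual is ~0 (only the epsilon's effect remains).
at_rest = spring_residual(np.array([0.0, -1.0]), np.array([0.0, 0.0]), 1.0)
print(at_rest)

# A free point at y = 0 adds (0.15 * 10)^2 = 2.25 to the sum of squared residuals.
print(gravity_residual(np.array([2.0, 0.0])) ** 2)

# An anchored point sitting on its target has a zero constraint residual.
print(anchor_residual(np.array([4.0, -1.0]), np.array([4.0, -1.0])))
```

Returning the spring residual as a 2D vector along the spring axis keeps its squared norm equal (up to the epsilon) to the squared stretch, so it penalizes the same quantity a scalar residual would.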

Problem setup

We'll create a grid of 5 columns by 4 rows (20 points). The left and right columns are anchored, and springs connect horizontally and vertically adjacent points:

# Grid dimensions.
cols, rows = 5, 4
num_points = cols * rows
spacing = 1.0


def idx(row: int, col: int) -> int:
    """Convert (row, col) to flat index.

    Args:
        row: Row index
        col: Column index

    Returns:
        Flat index into the points array
    """
    return row * cols + col


# Initial positions: a regular grid with rows stacked downward (y = -row * spacing).
initial_positions = jnp.array(
    [[c * spacing, -r * spacing] for r in range(rows) for c in range(cols)]
)

print(f"Grid: {cols}x{rows} = {num_points} points")
Grid: 5x4 = 20 points
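The flat index used throughout is row-major: idx(row, col) = row * cols + col. A NumPy-only sketch (jnp swapped for np so it runs without JAX) shows the mapping, the matching position in the initial grid, and the inverse mapping via divmod:

```python
import numpy as np

cols, rows = 5, 4
spacing = 1.0


def idx(row: int, col: int) -> int:
    # Row-major flattening, as in the notebook.
    return row * cols + col


# Same regular grid as initial_positions, built with NumPy instead of jnp.
positions = np.array(
    [[c * spacing, -r * spacing] for r in range(rows) for c in range(cols)]
)

# Point (row=2, col=3) lives at flat index 13, at x = 3.0, y = -2.0.
print(idx(2, 3), positions[idx(2, 3)])

# divmod inverts the mapping: flat index -> (row, col).
print(divmod(13, cols))
```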

Approach 1: naive construction

The straightforward approach: create variables and costs one-by-one using Python loops.

# Create variables for each grid point.
point_vars_naive = {
    (r, c): Point2Var(id=idx(r, c)) for r in range(rows) for c in range(cols)
}

costs_naive: list[jaxls.Cost] = []

# Anchor the left and right columns.
for row in range(rows):
    costs_naive.append(
        anchor_constraint(point_vars_naive[(row, 0)], initial_positions[idx(row, 0)])
    )
    costs_naive.append(
        anchor_constraint(
            point_vars_naive[(row, cols - 1)], initial_positions[idx(row, cols - 1)]
        )
    )

# Horizontal springs.
for row in range(rows):
    for col in range(cols - 1):
        costs_naive.append(
            spring_cost(
                point_vars_naive[(row, col)], point_vars_naive[(row, col + 1)], spacing
            )
        )

# Vertical springs.
for row in range(rows - 1):
    for col in range(cols):
        costs_naive.append(
            spring_cost(
                point_vars_naive[(row, col)], point_vars_naive[(row + 1, col)], spacing
            )
        )

# Gravity on non-anchor points.
for row in range(rows):
    for col in range(1, cols - 1):
        costs_naive.append(gravity_cost(point_vars_naive[(row, col)]))

print(f"Naive: Created {len(costs_naive)} cost objects")
Naive: Created 51 cost objects
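The 51 cost objects follow directly from the grid shape. A small hypothetical helper (count_terms, not part of jaxls) makes the bookkeeping explicit and shows how quickly the naive count grows with grid size:

```python
def count_terms(cols: int, rows: int) -> dict[str, int]:
    """Hypothetical helper: number of each term type for a cols x rows grid."""
    return {
        "anchors": 2 * rows,  # left and right column of every row
        "h_springs": rows * (cols - 1),  # between horizontal neighbours
        "v_springs": (rows - 1) * cols,  # between vertical neighbours
        "gravity": rows * (cols - 2),  # interior (non-anchored) points
    }


counts = count_terms(5, 4)
print(counts, sum(counts.values()))
```

For a 100x100 grid the same formulas give 29,800 terms: 29,800 separate cost objects with the naive loops, versus four calls in the batched construction.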

Approach 2: batched construction

When the variables passed to a cost factory have batched IDs (an ID array with a leading batch axis), a single call creates the whole batch of costs. This reduces Python overhead and speeds up problem analysis.

# Create ALL point variables at once with batched IDs.
all_point_vars = Point2Var(id=jnp.arange(num_points))

# Precompute index arrays.
anchor_indices = jnp.array([idx(r, c) for r in range(rows) for c in [0, cols - 1]])
anchor_positions = initial_positions[anchor_indices]

h_spring_a = jnp.array([idx(r, c) for r in range(rows) for c in range(cols - 1)])
h_spring_b = jnp.array([idx(r, c + 1) for r in range(rows) for c in range(cols - 1)])

v_spring_a = jnp.array([idx(r, c) for r in range(rows - 1) for c in range(cols)])
v_spring_b = jnp.array([idx(r + 1, c) for r in range(rows - 1) for c in range(cols)])

interior_indices = jnp.array(
    [idx(r, c) for r in range(rows) for c in range(1, cols - 1)]
)

# Create all costs using batched construction (4 calls instead of 51!)
costs_batched: list[jaxls.Cost] = [
    anchor_constraint(Point2Var(id=anchor_indices), anchor_positions),
    spring_cost(Point2Var(id=h_spring_a), Point2Var(id=h_spring_b), spacing),
    spring_cost(Point2Var(id=v_spring_a), Point2Var(id=v_spring_b), spacing),
    gravity_cost(Point2Var(id=interior_indices)),
]

print(
    f"Batched: Created {len(costs_batched)} cost objects (representing 51 individual costs)"
)
Batched: Created 4 cost objects (representing 51 individual costs)
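The index arrays can also be built without Python loops by slicing a 2D grid of flat indices. Here is a NumPy sketch for the same 5x4 grid; jnp.arange(...).reshape(...) supports the same slicing, but this variant is an alternative, not what the notebook itself uses:

```python
import numpy as np

cols, rows = 5, 4

# grid[r, c] == r * cols + c, the same row-major numbering as idx(r, c).
grid = np.arange(rows * cols).reshape(rows, cols)

h_a, h_b = grid[:, :-1].ravel(), grid[:, 1:].ravel()  # (r, c) -- (r, c + 1)
v_a, v_b = grid[:-1, :].ravel(), grid[1:, :].ravel()  # (r, c) -- (r + 1, c)
anchors = grid[:, [0, cols - 1]].ravel()  # left and right columns, row by row
interior = grid[:, 1:-1].ravel()  # every non-anchored point

print("horizontal:", list(zip(h_a.tolist(), h_b.tolist()))[:4])
print("vertical:  ", list(zip(v_a.tolist(), v_b.tolist()))[:5])
print("anchors:   ", anchors.tolist())
print("interior:  ", interior.tolist())
```

Because ravel() walks each sliced grid row by row, every array comes out in the same order as the list comprehensions above, so the pairing of springs and anchor targets is unchanged.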
# Visualize the full batched problem structure.
# The interactive D3.js visualization supports zoom, pan, and dragging nodes.
jaxls.LeastSquaresProblem(costs_batched, [all_point_vars]).show()

Solving

Both approaches build the same problem (the same 51 terms), so they produce the same result. We'll use the batched version:

# Create initial values.
initial_vals = jaxls.VarValues.make([all_point_vars.with_value(initial_positions)])

# Build and solve.
problem = jaxls.LeastSquaresProblem(costs_batched, [all_point_vars]).analyze()
solution = problem.solve(initial_vals)
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost
INFO     | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=1.9845e+02, constraint_dim=16
INFO     |  step #0: cost=19.8450 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #1: cost=19.8450 lambd=0.0010 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #2: cost=19.8450 lambd=0.0020 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #3: cost=19.8450 lambd=0.0040 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #4: cost=19.8450 lambd=0.0080 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #5: cost=19.8450 lambd=0.0160 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #6: cost=19.8450 lambd=0.0320 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #7: cost=19.8450 lambd=0.0640 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #8: cost=19.8450 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |  step #9: cost=19.8450 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 0.00000 (avg 0.00000)
INFO     |      - gravity_cost(12): 19.84501 (avg 1.65375)
INFO     |      accepted=True ATb_norm=9.73e-01 cost_prev=19.8450 cost_new=17.7450
INFO     |  step #10: cost=17.7450 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 3.55096 (avg 0.05727)
INFO     |      - gravity_cost(12): 14.19406 (avg 1.18284)
INFO     |  step #11: cost=17.7450 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 3.55096 (avg 0.05727)
INFO     |      - gravity_cost(12): 14.19406 (avg 1.18284)
INFO     |  step #12: cost=17.7450 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00000 (avg 0.00000)
INFO     |      - spring_cost(31): 3.55096 (avg 0.05727)
INFO     |      - gravity_cost(12): 14.19406 (avg 1.18284)
INFO     |      accepted=True ATb_norm=3.02e+00 cost_prev=17.7450 cost_new=16.9602
INFO     |  step #13: cost=16.9561 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00403 (avg 0.00025)
INFO     |      - spring_cost(31): 2.67261 (avg 0.04311)
INFO     |      - gravity_cost(12): 14.28351 (avg 1.19029)
INFO     |      accepted=True ATb_norm=1.77e+00 cost_prev=16.9602 cost_new=16.2103
INFO     |  step #14: cost=16.2057 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00456 (avg 0.00029)
INFO     |      - spring_cost(31): 1.61689 (avg 0.02608)
INFO     |      - gravity_cost(12): 14.58886 (avg 1.21574)
INFO     |      accepted=True ATb_norm=3.49e-01 cost_prev=16.2103 cost_new=16.1663
INFO     |  step #15: cost=16.1621 lambd=0.0640 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00414 (avg 0.00026)
INFO     |      - spring_cost(31): 1.27630 (avg 0.02059)
INFO     |      - gravity_cost(12): 14.88583 (avg 1.24049)
INFO     |  step #16: cost=16.1621 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00414 (avg 0.00026)
INFO     |      - spring_cost(31): 1.27630 (avg 0.02059)
INFO     |      - gravity_cost(12): 14.88583 (avg 1.24049)
INFO     |  step #17: cost=16.1621 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00414 (avg 0.00026)
INFO     |      - spring_cost(31): 1.27630 (avg 0.02059)
INFO     |      - gravity_cost(12): 14.88583 (avg 1.24049)
INFO     |  step #18: cost=16.1621 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00414 (avg 0.00026)
INFO     |      - spring_cost(31): 1.27630 (avg 0.02059)
INFO     |      - gravity_cost(12): 14.88583 (avg 1.24049)
INFO     |      accepted=True ATb_norm=1.89e-01 cost_prev=16.1663 cost_new=16.1507
INFO     |  step #19: cost=16.1464 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00431 (avg 0.00027)
INFO     |      - spring_cost(31): 1.29705 (avg 0.02092)
INFO     |      - gravity_cost(12): 14.84931 (avg 1.23744)
INFO     |  step #20: cost=16.1464 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00431 (avg 0.00027)
INFO     |      - spring_cost(31): 1.29705 (avg 0.02092)
INFO     |      - gravity_cost(12): 14.84931 (avg 1.23744)
INFO     |      accepted=True ATb_norm=1.08e-01 cost_prev=16.1507 cost_new=16.1459
INFO     |  step #21: cost=16.1416 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00422 (avg 0.00026)
INFO     |      - spring_cost(31): 1.26018 (avg 0.02033)
INFO     |      - gravity_cost(12): 14.88145 (avg 1.24012)
INFO     |  step #22: cost=16.1416 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00422 (avg 0.00026)
INFO     |      - spring_cost(31): 1.26018 (avg 0.02033)
INFO     |      - gravity_cost(12): 14.88145 (avg 1.24012)
INFO     |      accepted=True ATb_norm=6.47e-02 cost_prev=16.1459 cost_new=16.1444
INFO     |  step #23: cost=16.1401 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00429 (avg 0.00027)
INFO     |      - spring_cost(31): 1.26015 (avg 0.02033)
INFO     |      - gravity_cost(12): 14.87994 (avg 1.24000)
INFO     |  step #24: cost=16.1401 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00429 (avg 0.00027)
INFO     |      - spring_cost(31): 1.26015 (avg 0.02033)
INFO     |      - gravity_cost(12): 14.87994 (avg 1.24000)
INFO     |      accepted=True ATb_norm=3.27e-02 cost_prev=16.1444 cost_new=16.1439
INFO     |  step #25: cost=16.1397 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00424 (avg 0.00027)
INFO     |      - spring_cost(31): 1.25002 (avg 0.02016)
INFO     |      - gravity_cost(12): 14.88967 (avg 1.24081)
INFO     |      accepted=False ATb_norm=1.98e-02 cost_prev=16.1439 cost_new=16.1441
INFO     |  AL update: snorm=1.3472e-03, csupn=1.3472e-03, max_rho=7.9380e+02
INFO     |  step #26: cost=16.1397 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.02651 (avg 0.00166)
INFO     |      - spring_cost(31): 1.25002 (avg 0.02016)
INFO     |      - gravity_cost(12): 14.88967 (avg 1.24081)
INFO     |      accepted=True ATb_norm=3.91e+00 cost_prev=16.1662 cost_new=16.1491
INFO     |  step #27: cost=16.1481 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - augmented_anchor_constraint(8): 0.00105 (avg 0.00007)
INFO     |      - spring_cost(31): 1.26088 (avg 0.02034)
INFO     |      - gravity_cost(12): 14.88720 (avg 1.24060)
INFO     |      accepted=False ATb_norm=1.88e-02 cost_prev=16.1491 cost_new=16.1492
INFO     |  AL update: snorm=4.4343e-06, csupn=4.4343e-06, max_rho=7.9380e+02
INFO     | Terminated @ iteration #28: cost=16.1481 criteria=[1 0 0], term_deltas=6.3e-06,6.7e-03,6.0e-04
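Two numbers at the top of the log can be reproduced by hand. At the initial grid the anchors sit on their targets and every spring is at its rest length, so only gravity contributes to the step #0 cost. The sketch below assumes the logged cost is the plain sum of squared residuals (no factor of 1/2); the numbers it produces are consistent with that reading.

```python
# Reproduce two numbers from the solver log by hand (plain Python, no jaxls).
cols, rows, spacing = 5, 4, 1.0

# Free (interior) points start at y = -row * spacing, so each gravity residual
# is 0.15 * (y + 10); anchors and springs contribute ~0 at the initial grid.
gravity_residuals = [
    0.15 * (-row * spacing + 10.0) for row in range(rows) for _ in range(1, cols - 1)
]
initial_cost = sum(res**2 for res in gravity_residuals)
print(len(gravity_residuals), round(initial_cost, 4))  # log: "step #0: cost=19.8450"

# 8 anchors, each with a 2D residual, give the logged constraint_dim=16.
num_anchors = 2 * rows
constraint_dim = num_anchors * 2
print(num_anchors, constraint_dim)
```

Written out, the initial cost is 3 * (1.5^2 + 1.35^2 + 1.2^2 + 1.05^2) = 3 * 6.615 = 19.845.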

Visualization


import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML


def get_traces(vals: jaxls.VarValues) -> list[go.Scatter]:
    """Get all Plotly traces for visualization.

    Args:
        vals: Variable values containing point positions

    Returns:
        List of Plotly Scatter traces for springs and points
    """
    positions = vals[all_point_vars]
    traces = []

    # Spring lines.
    for a, b in [(h_spring_a, h_spring_b), (v_spring_a, v_spring_b)]:
        for i, j in zip(a, b):
            p1, p2 = positions[i], positions[j]
            traces.append(
                go.Scatter(
                    x=[p1[0], p2[0]],
                    y=[p1[1], p2[1]],
                    mode="lines",
                    line=dict(color="gray", width=1),
                    opacity=0.5,
                    hoverinfo="skip",
                    showlegend=False,
                )
            )

    # Anchor points.
    anchor_pos = positions[anchor_indices]
    traces.append(
        go.Scatter(
            x=anchor_pos[:, 0],
            y=anchor_pos[:, 1],
            mode="markers",
            marker=dict(size=12, color="crimson"),
            name="Anchors",
            hovertemplate="(%{x:.2f}, %{y:.2f})",
        )
    )

    # Free points.
    free_pos = positions[interior_indices]
    traces.append(
        go.Scatter(
            x=free_pos[:, 0],
            y=free_pos[:, 1],
            mode="markers",
            marker=dict(size=10, color="steelblue"),
            name="Free points",
            hovertemplate="(%{x:.2f}, %{y:.2f})",
        )
    )
    return traces


fig = make_subplots(rows=1, cols=2, subplot_titles=("Initial", "Optimized"))
for trace in get_traces(initial_vals):
    fig.add_trace(trace, row=1, col=1)
for trace in get_traces(solution):
    fig.add_trace(trace, row=1, col=2)
fig.update_xaxes(title_text="x", range=[-0.5, 4.5], scaleanchor="y", scaleratio=1)
fig.update_yaxes(title_text="y", range=[-5, 0.5])
fig.update_layout(height=400, showlegend=False, margin=dict(t=40, b=40, l=40, r=40))
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Relaxation animation

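We can visualize the relaxation process by solving a sequence of problems with progressively increasing gravity. Each solve is warm-started from the previous frame's solution, so the frames trace a quasi-static path from the initial grid to the final equilibrium. This is a sequence of static equilibria rather than a dynamic simulation (there is no inertia or damping):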

@jaxls.Cost.factory
def gravity_cost_scaled(
    vals: jaxls.VarValues,
    var: Point2Var,
    scale: float,
) -> jax.Array:
    """Pull points downward with adjustable strength."""
    return (vals[var][1] + 10.0) * 0.15 * scale


# Solve for a sequence of gravity strengths.
n_frames = 20
gravity_scales = jnp.linspace(0.0, 1.0, n_frames)
frame_solutions = [initial_vals]

current_vals = initial_vals
for scale in gravity_scales[1:]:
    costs_scaled: list[jaxls.Cost] = [
        anchor_constraint(Point2Var(id=anchor_indices), anchor_positions),
        spring_cost(Point2Var(id=h_spring_a), Point2Var(id=h_spring_b), spacing),
        spring_cost(Point2Var(id=v_spring_a), Point2Var(id=v_spring_b), spacing),
        gravity_cost_scaled(Point2Var(id=interior_indices), float(scale)),
    ]
    problem_scaled = jaxls.LeastSquaresProblem(costs_scaled, [all_point_vars]).analyze()
    current_vals = problem_scaled.solve(current_vals, verbose=False)
    frame_solutions.append(current_vals)

print(f"Generated {len(frame_solutions)} frames")
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Building optimization problem with 51 terms and 20 variables: 43 costs, 8 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 31 costs, 2 variables each: spring_cost
INFO     | Vectorizing group with 12 costs, 1 variables each: gravity_cost_scaled
INFO     | Vectorizing constraint group with 8 constraints (constraint_eq_zero), 1 variables each: augmented_anchor_constraint
Generated 20 frames
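The frame loop above is a continuation scheme: the load is ramped up in small steps, and each solve starts from the previous equilibrium. A minimal sketch of the same pattern on a hypothetical one-dimensional problem follows: a stiffening spring k * x + k3 * x**3 balancing a load scale * w, solved with scalar Newton iterations instead of jaxls. All names and constants here are made up for illustration.

```python
import numpy as np


def solve_newton(f, df, x0: float, tol: float = 1e-12, max_iter: int = 50) -> tuple[float, int]:
    """Scalar Newton's method; returns the root estimate and the iterations used."""
    x = x0
    for it in range(max_iter):
        fx = f(x)
        if abs(fx) < tol:
            return x, it
        x -= fx / df(x)
    return x, max_iter


# Hypothetical stiffening spring: k * x + k3 * x**3 must balance the load scale * w.
k, k3, w = 1.0, 0.5, 4.0

scales = np.linspace(0.0, 1.0, 20)
x = 0.0  # unloaded configuration, playing the role of initial_vals
path, iterations = [x], []
for s in scales[1:]:
    # Warm start: the previous equilibrium is the initial guess for the next load.
    x, n = solve_newton(
        lambda u, s=s: k * u + k3 * u**3 - s * w,
        lambda u: k + 3 * k3 * u**2,
        x,
    )
    path.append(x)
    iterations.append(n)

print(f"final displacement {path[-1]:.4f}, max Newton iterations per frame {max(iterations)}")
```

Since consecutive loads differ only slightly, each warm start begins close to the next equilibrium, which is also what keeps the per-frame solves in the notebook cheap.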


def get_frame_data(vals: jaxls.VarValues) -> dict:
    """Extract position data for animation frame."""
    positions = vals[all_point_vars]

    # Build spring line coordinates.
    spring_x: list[float | None] = []
    spring_y: list[float | None] = []
    for a, b in [(h_spring_a, h_spring_b), (v_spring_a, v_spring_b)]:
        for i, j in zip(a, b):
            p1, p2 = positions[i], positions[j]
            spring_x.extend([float(p1[0]), float(p2[0]), None])
            spring_y.extend([float(p1[1]), float(p2[1]), None])

    anchor_pos = positions[anchor_indices]
    free_pos = positions[interior_indices]

    return dict(
        spring_x=spring_x,
        spring_y=spring_y,
        anchor_x=[float(x) for x in anchor_pos[:, 0]],
        anchor_y=[float(y) for y in anchor_pos[:, 1]],
        free_x=[float(x) for x in free_pos[:, 0]],
        free_y=[float(y) for y in free_pos[:, 1]],
    )


# Build animation frames.
frames = []
for i, vals in enumerate(frame_solutions):
    data = get_frame_data(vals)
    frames.append(
        go.Frame(
            data=[
                go.Scatter(
                    x=data["spring_x"],
                    y=data["spring_y"],
                    mode="lines",
                    line=dict(color="gray", width=2),
                    hoverinfo="skip",
                ),
                go.Scatter(
                    x=data["anchor_x"],
                    y=data["anchor_y"],
                    mode="markers",
                    marker=dict(size=12, color="crimson"),
                ),
                go.Scatter(
                    x=data["free_x"],
                    y=data["free_y"],
                    mode="markers",
                    marker=dict(size=10, color="steelblue"),
                ),
            ],
            name=str(i),
        )
    )

# Initial frame data.
init_data = get_frame_data(initial_vals)

fig_anim = go.Figure(
    data=[
        go.Scatter(
            x=init_data["spring_x"],
            y=init_data["spring_y"],
            mode="lines",
            line=dict(color="gray", width=2),
            hoverinfo="skip",
            name="Springs",
        ),
        go.Scatter(
            x=init_data["anchor_x"],
            y=init_data["anchor_y"],
            mode="markers",
            marker=dict(size=12, color="crimson"),
            name="Anchors",
        ),
        go.Scatter(
            x=init_data["free_x"],
            y=init_data["free_y"],
            mode="markers",
            marker=dict(size=10, color="steelblue"),
            name="Free points",
        ),
    ],
    frames=frames,
    layout=go.Layout(
        title="Spring network relaxation",
        xaxis=dict(range=[-0.5, 4.5], title="x", scaleanchor="y", scaleratio=1),
        yaxis=dict(range=[-5, 0.5], title="y"),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=1.15,
                x=0.5,
                xanchor="center",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            dict(
                                frame=dict(duration=100, redraw=True),
                                fromcurrent=True,
                                transition=dict(duration=50),
                            ),
                        ],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[
                            [None],
                            dict(
                                frame=dict(duration=0, redraw=False), mode="immediate"
                            ),
                        ],
                    ),
                ],
            )
        ],
        sliders=[
            dict(
                active=0,
                yanchor="top",
                xanchor="left",
                currentvalue=dict(
                    prefix="Gravity: ", suffix="%", visible=True, xanchor="center"
                ),
                pad=dict(b=10, t=50),
                steps=[
                    dict(
                        args=[
                            [str(i)],
                            dict(
                                frame=dict(duration=0, redraw=True),
                                mode="immediate",
                                transition=dict(duration=0),
                            ),
                        ],
                        # round() rather than int(): int() truncates, e.g. 0.5789 -> 57.
                        label=f"{round(float(gravity_scales[i]) * 100)}",
                        method="animate",
                    )
                    for i in range(n_frames)
                ],
                x=0.1,
                y=0,
                len=0.8,
            )
        ],
        height=450,
        showlegend=False,
        margin=dict(t=80, b=80),
    ),
)
HTML(fig_anim.to_html(full_html=False, include_plotlyjs="cdn", auto_play=False))

The animation shows how the spring network relaxes as gravity is progressively applied. Starting from the regular grid (0% gravity), the interior points sag downward until reaching the final equilibrium (100% gravity).

When to use each construction approach:

  • Naive: Simpler code, good for prototyping and small problems

  • Batched: Faster problem analysis, essential for large-scale problems

For more details, see jaxls.Var, jaxls.Cost, and jaxls.LeastSquaresProblem.