Obstacle avoidance

In this notebook, we solve an obstacle avoidance problem: finding a smooth path from start to goal while avoiding circular obstacles.

Features used:

  • Var for 2D waypoint positions

  • @jaxls.Cost.factory for smoothness and anchor costs

  • Inequality constraints (constraint_leq_zero) for obstacle avoidance

  • Augmented Lagrangian solver for constrained optimization

Hide code cell source

import sys
from loguru import logger

# Replace loguru's handlers with a single compact stdout sink, so solver
# progress shows up in the notebook as "LEVEL    | message".
logger.remove()
logger.add(
    sink=sys.stdout,
    format="<level>{level: <8}</level> | {message}",
);
import jax
import jax.numpy as jnp
import jaxls

Problem setup

Plan a trajectory with 20 waypoints from start to goal, avoiding two obstacles:

# Trajectory parameters: waypoint count and the fixed start/goal positions.
n_waypoints = 20
start, goal = jnp.array([0.0, 0.0]), jnp.array([10.0, 0.0])

# Circular obstacles, each described as (center_x, center_y, radius).
obstacles = [
    (3.5, 1.0, 1.5),   # upper obstacle, nearer the start
    (6.5, -1.0, 1.5),  # lower obstacle, nearer the goal
]

Variables and costs

Variable instances are PyTrees, so a single WaypointVar can hold a whole batch of IDs. We define the waypoint variable type and the cost functions:

class WaypointVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((2,))):
    """A 2D waypoint position, stored as a length-2 array.

    ``default_factory`` supplies a zero vector as this variable type's default value.
    """


# Batched variable creation: an array of IDs yields one variable per waypoint.
waypoint_vars = WaypointVar(id=jnp.arange(0, n_waypoints))
@jaxls.Cost.factory
def smoothness_cost(
    vals: jaxls.VarValues,
    var_prev: WaypointVar,
    var_curr: WaypointVar,
    var_next: WaypointVar,
) -> jax.Array:
    """Residual on the discrete second difference of three consecutive waypoints.

    Returns ``(p_prev - 2 * p_curr + p_next) * 5.0``, a scaled finite-difference
    acceleration. Minimizing its squared norm in the least-squares problem
    favors smooth paths.
    """
    p_prev = vals[var_prev]
    p_curr = vals[var_curr]
    p_next = vals[var_next]
    weight = 5.0  # Smoothness weight applied to the residual.
    return (p_prev - 2 * p_curr + p_next) * weight


@jaxls.Cost.factory
def anchor_cost(
    vals: jaxls.VarValues,
    var: WaypointVar,
    target: jax.Array,
) -> jax.Array:
    """Residual tying one waypoint to a fixed target position.

    Returns the offset from ``target`` scaled by 10.0. The comparatively large
    weight holds the waypoint close to the target.
    """
    weight = 10.0  # Strong weight so the anchored waypoint barely moves.
    offset = vals[var] - target
    return offset * weight


@jaxls.Cost.factory(kind="constraint_leq_zero")
def obstacle_constraint(
    vals: jaxls.VarValues,
    var: WaypointVar,
    center: jax.Array,
    radius: float,
) -> jax.Array:
    """Inequality constraint keeping one waypoint outside one circular obstacle.

    Returns ``radius - ||p - center||`` as a length-1 array. Because this is a
    ``constraint_leq_zero`` cost, the solver drives the value to be ``<= 0``,
    i.e. the waypoint stays at least ``radius`` away from ``center``.
    """
    offset = vals[var] - center
    # The tiny epsilon keeps sqrt's gradient finite when the point sits exactly
    # on the center.
    distance = jnp.sqrt(jnp.sum(offset**2) + 1e-6)
    violation = radius - distance
    return jnp.array([violation])

Problem construction

# Smoothness triples over consecutive interior waypoints:
#   smooth_prev = 0..n-3, smooth_curr = 1..n-2, smooth_next = 2..n-1.
smooth_prev, smooth_curr, smooth_next = (
    jnp.arange(shift, n_waypoints - 2 + shift) for shift in range(3)
)

# Obstacle constraints: one per (obstacle, waypoint) pair, laid out
# obstacle-major. Entries k*n_waypoints .. (k+1)*n_waypoints - 1 pair
# obstacle k with waypoints 0 .. n_waypoints - 1.
n_obstacles = len(obstacles)
waypoint_ids_for_obs = jnp.tile(jnp.arange(n_waypoints), n_obstacles)
obstacle_centers = jnp.asarray(obstacles)[:, :2]  # (n_obstacles, 2)
obstacle_radii = jnp.asarray(obstacles)[:, 2]  # (n_obstacles,)
# Each obstacle's parameters are repeated once per waypoint to match the IDs above.
centers_repeated = jnp.repeat(obstacle_centers, n_waypoints, axis=0)
radii_repeated = jnp.repeat(obstacle_radii, n_waypoints)

# Assemble every cost term. The smoothness and obstacle entries are batched:
# passing arrays of IDs/parameters yields one cost object covering many terms.
costs: list[jaxls.Cost] = [
    # Pin the first waypoint to the start and the last to the goal.
    anchor_cost(WaypointVar(id=0), start),
    anchor_cost(WaypointVar(id=n_waypoints - 1), goal),
]
# One smoothness term per consecutive (prev, curr, next) triple.
costs.append(
    smoothness_cost(
        WaypointVar(id=smooth_prev),
        WaypointVar(id=smooth_curr),
        WaypointVar(id=smooth_next),
    )
)
# One inequality constraint per (obstacle, waypoint) pair.
costs.append(
    obstacle_constraint(
        WaypointVar(id=waypoint_ids_for_obs),
        centers_repeated,
        radii_repeated,
    )
)

print(f"Created {len(costs)} batched cost objects")
Created 4 batched cost objects
# Initial guess: waypoints evenly spaced along the straight segment from start
# to goal. This segment passes through both obstacles, so some of the obstacle
# constraints start out violated.
t = jnp.linspace(0, 1, n_waypoints)
initial_positions = start + t[:, None] * (goal - start)  # (n_waypoints, 2)

initial_vals = jaxls.VarValues.make([waypoint_vars.with_value(initial_positions)])

# Create the problem.
problem = jaxls.LeastSquaresProblem(costs, [waypoint_vars])

# Visualize the problem structure.
problem.show()
# Analyze the problem; the log output below shows how costs are grouped/vectorized.
problem = problem.analyze()
INFO     | Building optimization problem with 60 terms and 20 variables: 20 costs, 0 eq_zero, 40 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 2 costs, 1 variables each: anchor_cost
INFO     | Vectorizing constraint group with 40 constraints (constraint_leq_zero), 1 variables each: augmented_obstacle_constraint
INFO     | Vectorizing group with 18 costs, 3 variables each: smoothness_cost

Solving

# Solve from the straight-line initial guess. The obstacle inequality
# constraints are handled by the augmented Lagrangian solver; its "AL update"
# lines in the log below report the remaining constraint violation.
solution = problem.solve(initial_vals)
INFO     | Augmented Lagrangian: initial snorm=4.8317e-01, csupn=4.8317e-01, max_rho=1.0706e+01, constraint_dim=40
INFO     |  step #0: cost=0.0000 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - anchor_cost(2): 0.00000 (avg 0.00000)
INFO     |      - augmented_obstacle_constraint(40): 11.46163 (avg 0.28654)
INFO     |      - smoothness_cost(18): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.19e+01 cost_prev=11.4616 cost_new=0.9302
INFO     |  step #1: cost=0.8331 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - anchor_cost(2): 0.00119 (avg 0.00030)
INFO     |      - augmented_obstacle_constraint(40): 0.09714 (avg 0.00243)
INFO     |      - smoothness_cost(18): 0.83190 (avg 0.02311)
INFO     |      accepted=True ATb_norm=9.96e-01 cost_prev=0.9302 cost_new=0.8834
INFO     |  step #2: cost=0.7775 lambd=0.0001 inexact_tol=6.3e-03
INFO     |      - anchor_cost(2): 0.00125 (avg 0.00031)
INFO     |      - augmented_obstacle_constraint(40): 0.10585 (avg 0.00265)
INFO     |      - smoothness_cost(18): 0.77627 (avg 0.02156)
INFO     |      accepted=True ATb_norm=4.80e-02 cost_prev=0.8834 cost_new=0.8828
INFO     |  step #3: cost=0.7712 lambd=0.0001 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00122 (avg 0.00030)
INFO     |      - augmented_obstacle_constraint(40): 0.11158 (avg 0.00279)
INFO     |      - smoothness_cost(18): 0.76997 (avg 0.02139)
INFO     |      accepted=True ATb_norm=6.31e-03 cost_prev=0.8828 cost_new=0.8828
INFO     |  AL update: snorm=7.0361e-02, csupn=7.0361e-02, max_rho=1.0706e+01
INFO     |  step #4: cost=0.7707 lambd=0.0000 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00122 (avg 0.00030)
INFO     |      - augmented_obstacle_constraint(40): 0.44837 (avg 0.01121)
INFO     |      - smoothness_cost(18): 0.76945 (avg 0.02137)
INFO     |      accepted=True ATb_norm=1.18e+00 cost_prev=1.2190 cost_new=1.1491
INFO     |  step #5: cost=0.9132 lambd=0.0000 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00134 (avg 0.00034)
INFO     |      - augmented_obstacle_constraint(40): 0.23599 (avg 0.00590)
INFO     |      - smoothness_cost(18): 0.91181 (avg 0.02533)
INFO     |      accepted=True ATb_norm=1.96e-01 cost_prev=1.1491 cost_new=1.1436
INFO     |  step #6: cost=0.9509 lambd=0.0000 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00148 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.19274 (avg 0.00482)
INFO     |      - smoothness_cost(18): 0.94939 (avg 0.02637)
INFO     |      accepted=True ATb_norm=1.86e-02 cost_prev=1.1436 cost_new=1.1435
INFO     |  step #7: cost=0.9518 lambd=0.0000 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00148 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.19162 (avg 0.00479)
INFO     |      - smoothness_cost(18): 0.95036 (avg 0.02640)
INFO     |      accepted=True ATb_norm=4.88e-03 cost_prev=1.1435 cost_new=1.1435
INFO     |  AL update: snorm=2.4076e-02, csupn=2.4076e-02, max_rho=1.0706e+01
INFO     |  step #8: cost=0.9522 lambd=0.0000 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00148 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.30073 (avg 0.00752)
INFO     |      - smoothness_cost(18): 0.95070 (avg 0.02641)
INFO     |      accepted=True ATb_norm=3.96e-01 cost_prev=1.2529 cost_new=1.2444
INFO     |  step #9: cost=1.0277 lambd=0.0000 inexact_tol=2.1e-03
INFO     |      - anchor_cost(2): 0.00161 (avg 0.00040)
INFO     |      - augmented_obstacle_constraint(40): 0.21676 (avg 0.00542)
INFO     |      - smoothness_cost(18): 1.02608 (avg 0.02850)
INFO     |      accepted=True ATb_norm=5.07e-03 cost_prev=1.2444 cost_new=1.2444
INFO     |  AL update: snorm=6.1835e-03, csupn=6.1835e-03, max_rho=1.0706e+01
INFO     |  step #10: cost=1.0277 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00161 (avg 0.00040)
INFO     |      - augmented_obstacle_constraint(40): 0.24424 (avg 0.00611)
INFO     |      - smoothness_cost(18): 1.02605 (avg 0.02850)
INFO     |      accepted=True ATb_norm=1.00e-01 cost_prev=1.2719 cost_new=1.2712
INFO     |  step #11: cost=1.0498 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00164 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22144 (avg 0.00554)
INFO     |      - smoothness_cost(18): 1.04813 (avg 0.02911)
INFO     |      accepted=False ATb_norm=5.79e-04 cost_prev=1.2712 cost_new=1.2712
INFO     |  AL update: snorm=1.0778e-03, csupn=1.0778e-03, max_rho=1.0706e+01
INFO     |  step #12: cost=1.0498 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00164 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22616 (avg 0.00565)
INFO     |      - smoothness_cost(18): 1.04813 (avg 0.02911)
INFO     |      accepted=True ATb_norm=1.75e-02 cost_prev=1.2759 cost_new=1.2759
INFO     |  step #13: cost=1.0536 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22226 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05200 (avg 0.02922)
INFO     |      accepted=False ATb_norm=1.70e-04 cost_prev=1.2759 cost_new=1.2759
INFO     |  AL update: snorm=1.8895e-04, csupn=1.8895e-04, max_rho=1.0706e+01
INFO     |  step #14: cost=1.0536 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22308 (avg 0.00558)
INFO     |      - smoothness_cost(18): 1.05200 (avg 0.02922)
INFO     |      accepted=True ATb_norm=3.07e-03 cost_prev=1.2767 cost_new=1.2767
INFO     |  AL update: snorm=3.4213e-05, csupn=3.4213e-05, max_rho=1.0706e+01
INFO     |  step #15: cost=1.0543 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22255 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05267 (avg 0.02924)
INFO     |  step #16: cost=1.0543 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22255 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05267 (avg 0.02924)
INFO     |  step #17: cost=1.0543 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22255 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05267 (avg 0.02924)
INFO     |  step #18: cost=1.0543 lambd=0.0001 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22255 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05267 (avg 0.02924)
INFO     |  step #19: cost=1.0543 lambd=0.0002 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22255 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05267 (avg 0.02924)
INFO     |  step #20: cost=1.0543 lambd=0.0003 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22255 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05267 (avg 0.02924)
INFO     |      accepted=True ATb_norm=5.45e-04 cost_prev=1.2769 cost_new=1.2769
INFO     |  AL update: snorm=7.9870e-06, csupn=7.9870e-06, max_rho=1.0706e+01
INFO     |  step #21: cost=1.0544 lambd=0.0002 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.22246 (avg 0.00556)
INFO     |      - smoothness_cost(18): 1.05279 (avg 0.02924)
INFO     |      accepted=True ATb_norm=2.60e-04 cost_prev=1.2769 cost_new=1.2769
INFO     |  AL update: snorm=7.8678e-06, csupn=7.8678e-06, max_rho=4.2824e+01
INFO     |  step #22: cost=1.0544 lambd=0.0001 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.13904 (avg 0.00348)
INFO     |      - smoothness_cost(18): 1.05280 (avg 0.02924)
INFO     |      accepted=True ATb_norm=9.80e-04 cost_prev=1.1935 cost_new=1.1935
INFO     |  AL update: snorm=7.9870e-06, csupn=7.9870e-06, max_rho=4.2824e+01
INFO     |  step #23: cost=1.0544 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.05563 (avg 0.00139)
INFO     |      - smoothness_cost(18): 1.05277 (avg 0.02924)
INFO     |      accepted=True ATb_norm=4.92e-03 cost_prev=1.1100 cost_new=1.1085
INFO     |  AL update: snorm=3.7730e-04, csupn=0.0000e+00, max_rho=4.2824e+01
INFO     |  step #24: cost=1.0543 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00165 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.05282 (avg 0.00132)
INFO     |      - smoothness_cost(18): 1.05268 (avg 0.02924)
INFO     |      accepted=True ATb_norm=5.30e-02 cost_prev=1.1071 cost_new=1.0764
INFO     |  step #25: cost=1.0434 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00163 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.03300 (avg 0.00082)
INFO     |      - smoothness_cost(18): 1.04181 (avg 0.02894)
INFO     |      accepted=True ATb_norm=1.07e+00 cost_prev=1.0764 cost_new=1.0461
INFO     |  step #26: cost=0.9435 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00147 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.10259 (avg 0.00256)
INFO     |      - smoothness_cost(18): 0.94204 (avg 0.02617)
INFO     |      accepted=True ATb_norm=1.47e+00 cost_prev=1.0461 cost_new=1.0058
INFO     |  step #27: cost=0.9217 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00149 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.08418 (avg 0.00210)
INFO     |      - smoothness_cost(18): 0.92017 (avg 0.02556)
INFO     |      accepted=True ATb_norm=9.79e-02 cost_prev=1.0058 cost_new=1.0050
INFO     |  step #28: cost=0.9221 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.08283 (avg 0.00207)
INFO     |      - smoothness_cost(18): 0.92059 (avg 0.02557)
INFO     |      accepted=True ATb_norm=2.17e-02 cost_prev=1.0050 cost_new=1.0049
INFO     |  step #29: cost=0.9211 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.08382 (avg 0.00210)
INFO     |      - smoothness_cost(18): 0.91957 (avg 0.02554)
INFO     |      accepted=True ATb_norm=2.53e-03 cost_prev=1.0049 cost_new=1.0049
INFO     |  AL update: snorm=5.8322e-02, csupn=5.8322e-02, max_rho=4.2824e+01
INFO     |  step #30: cost=0.9206 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.47436 (avg 0.01186)
INFO     |      - smoothness_cost(18): 0.91904 (avg 0.02553)
INFO     |      accepted=True ATb_norm=3.92e+00 cost_prev=1.3949 cost_new=1.1356
INFO     |  step #31: cost=1.0855 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00191 (avg 0.00048)
INFO     |      - augmented_obstacle_constraint(40): 0.05005 (avg 0.00125)
INFO     |      - smoothness_cost(18): 1.08361 (avg 0.03010)
INFO     |      accepted=True ATb_norm=2.99e-01 cost_prev=1.1356 cost_new=1.0921
INFO     |  step #32: cost=1.0651 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #33: cost=1.0651 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #34: cost=1.0651 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #35: cost=1.0651 lambd=0.0001 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #36: cost=1.0651 lambd=0.0002 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #37: cost=1.0651 lambd=0.0003 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #38: cost=1.0651 lambd=0.0006 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #39: cost=1.0651 lambd=0.0013 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #40: cost=1.0651 lambd=0.0026 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #41: cost=1.0651 lambd=0.0051 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #42: cost=1.0651 lambd=0.0102 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #43: cost=1.0651 lambd=0.0205 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |  step #44: cost=1.0651 lambd=0.0410 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00184 (avg 0.00046)
INFO     |      - augmented_obstacle_constraint(40): 0.02691 (avg 0.00067)
INFO     |      - smoothness_cost(18): 1.06331 (avg 0.02954)
INFO     |      accepted=True ATb_norm=4.05e-01 cost_prev=1.0921 cost_new=1.0892
INFO     |  step #45: cost=1.0003 lambd=0.0205 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00126 (avg 0.00031)
INFO     |      - augmented_obstacle_constraint(40): 0.08890 (avg 0.00222)
INFO     |      - smoothness_cost(18): 0.99903 (avg 0.02775)
INFO     |      accepted=True ATb_norm=1.76e+00 cost_prev=1.0892 cost_new=1.0462
INFO     |  step #46: cost=1.0062 lambd=0.0102 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00158 (avg 0.00039)
INFO     |      - augmented_obstacle_constraint(40): 0.04001 (avg 0.00100)
INFO     |      - smoothness_cost(18): 1.00464 (avg 0.02791)
INFO     |      accepted=True ATb_norm=4.79e-02 cost_prev=1.0462 cost_new=1.0443
INFO     |  step #47: cost=1.0014 lambd=0.0051 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04295 (avg 0.00107)
INFO     |      - smoothness_cost(18): 0.99985 (avg 0.02777)
INFO     |      accepted=True ATb_norm=3.50e-02 cost_prev=1.0443 cost_new=1.0437
INFO     |  step #48: cost=0.9985 lambd=0.0026 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00152 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04516 (avg 0.00113)
INFO     |      - smoothness_cost(18): 0.99699 (avg 0.02769)
INFO     |      accepted=True ATb_norm=1.30e-02 cost_prev=1.0437 cost_new=1.0434
INFO     |  step #49: cost=0.9969 lambd=0.0013 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00151 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04650 (avg 0.00116)
INFO     |      - smoothness_cost(18): 0.99543 (avg 0.02765)
INFO     |      accepted=True ATb_norm=1.22e-02 cost_prev=1.0434 cost_new=1.0431
INFO     |  step #50: cost=0.9972 lambd=0.0006 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04593 (avg 0.00115)
INFO     |      - smoothness_cost(18): 0.99568 (avg 0.02766)
INFO     |      accepted=True ATb_norm=1.23e-02 cost_prev=1.0431 cost_new=1.0428
INFO     |  step #51: cost=0.9968 lambd=0.0003 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00152 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04598 (avg 0.00115)
INFO     |      - smoothness_cost(18): 0.99526 (avg 0.02765)
INFO     |      accepted=True ATb_norm=1.28e-02 cost_prev=1.0428 cost_new=1.0423
INFO     |  step #52: cost=0.9961 lambd=0.0002 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04618 (avg 0.00115)
INFO     |      - smoothness_cost(18): 0.99462 (avg 0.02763)
INFO     |      accepted=True ATb_norm=1.37e-02 cost_prev=1.0423 cost_new=1.0419
INFO     |  step #53: cost=0.9953 lambd=0.0001 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04655 (avg 0.00116)
INFO     |      - smoothness_cost(18): 0.99378 (avg 0.02760)
INFO     |      accepted=True ATb_norm=1.43e-02 cost_prev=1.0419 cost_new=1.0414
INFO     |  step #54: cost=0.9943 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04709 (avg 0.00118)
INFO     |      - smoothness_cost(18): 0.99275 (avg 0.02758)
INFO     |      accepted=True ATb_norm=1.45e-02 cost_prev=1.0414 cost_new=1.0409
INFO     |  step #55: cost=0.9931 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04778 (avg 0.00119)
INFO     |      - smoothness_cost(18): 0.99160 (avg 0.02754)
INFO     |      accepted=True ATb_norm=1.41e-02 cost_prev=1.0409 cost_new=1.0405
INFO     |  step #56: cost=0.9919 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04859 (avg 0.00121)
INFO     |      - smoothness_cost(18): 0.99037 (avg 0.02751)
INFO     |      accepted=True ATb_norm=1.29e-02 cost_prev=1.0405 cost_new=1.0402
INFO     |  step #57: cost=0.9907 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.04946 (avg 0.00124)
INFO     |      - smoothness_cost(18): 0.98916 (avg 0.02748)
INFO     |      accepted=True ATb_norm=1.18e-02 cost_prev=1.0402 cost_new=1.0399
INFO     |  step #58: cost=0.9895 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00039)
INFO     |      - augmented_obstacle_constraint(40): 0.05034 (avg 0.00126)
INFO     |      - smoothness_cost(18): 0.98801 (avg 0.02744)
INFO     |      accepted=True ATb_norm=1.03e-02 cost_prev=1.0399 cost_new=1.0397
INFO     |  step #59: cost=0.9885 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00039)
INFO     |      - augmented_obstacle_constraint(40): 0.05118 (avg 0.00128)
INFO     |      - smoothness_cost(18): 0.98697 (avg 0.02742)
INFO     |      accepted=True ATb_norm=8.80e-03 cost_prev=1.0397 cost_new=1.0396
INFO     |  AL update: snorm=1.9573e-02, csupn=1.9573e-02, max_rho=1.7129e+02
INFO     |  step #60: cost=0.9876 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00039)
INFO     |      - augmented_obstacle_constraint(40): 0.16703 (avg 0.00418)
INFO     |      - smoothness_cost(18): 0.98607 (avg 0.02739)
INFO     |      accepted=True ATb_norm=4.18e+00 cost_prev=1.1546 cost_new=1.0663
INFO     |  step #61: cost=1.0532 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00169 (avg 0.00042)
INFO     |      - augmented_obstacle_constraint(40): 0.01316 (avg 0.00033)
INFO     |      - smoothness_cost(18): 1.05146 (avg 0.02921)
INFO     |      accepted=True ATb_norm=6.24e-02 cost_prev=1.0663 cost_new=1.0615
INFO     |  step #62: cost=1.0455 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #63: cost=1.0455 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #64: cost=1.0455 lambd=0.0000 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #65: cost=1.0455 lambd=0.0001 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #66: cost=1.0455 lambd=0.0002 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #67: cost=1.0455 lambd=0.0003 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #68: cost=1.0455 lambd=0.0006 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #69: cost=1.0455 lambd=0.0013 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #70: cost=1.0455 lambd=0.0026 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #71: cost=1.0455 lambd=0.0051 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #72: cost=1.0455 lambd=0.0102 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #73: cost=1.0455 lambd=0.0205 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #74: cost=1.0455 lambd=0.0410 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #75: cost=1.0455 lambd=0.0819 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #76: cost=1.0455 lambd=0.1638 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #77: cost=1.0455 lambd=0.3277 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #78: cost=1.0455 lambd=0.6554 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #79: cost=1.0455 lambd=1.3107 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #80: cost=1.0455 lambd=2.6214 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #81: cost=1.0455 lambd=5.2429 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #82: cost=1.0455 lambd=10.4858 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #83: cost=1.0455 lambd=20.9715 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |  step #84: cost=1.0455 lambd=41.9430 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00171 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.01597 (avg 0.00040)
INFO     |      - smoothness_cost(18): 1.04382 (avg 0.02900)
INFO     |      accepted=True ATb_norm=1.73e+00 cost_prev=1.0615 cost_new=1.0396
INFO     |  step #85: cost=1.0023 lambd=20.9715 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00163 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.03732 (avg 0.00093)
INFO     |      - smoothness_cost(18): 1.00065 (avg 0.02780)
INFO     |      accepted=True ATb_norm=1.42e+00 cost_prev=1.0396 cost_new=1.0322
INFO     |  step #86: cost=1.0073 lambd=10.4858 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00166 (avg 0.00041)
INFO     |      - augmented_obstacle_constraint(40): 0.02494 (avg 0.00062)
INFO     |      - smoothness_cost(18): 1.00562 (avg 0.02793)
INFO     |      accepted=True ATb_norm=1.63e-01 cost_prev=1.0322 cost_new=1.0302
INFO     |  step #87: cost=1.0068 lambd=5.2429 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00173 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.02337 (avg 0.00058)
INFO     |      - smoothness_cost(18): 1.00507 (avg 0.02792)
INFO     |      accepted=True ATb_norm=9.38e-02 cost_prev=1.0302 cost_new=1.0282
INFO     |  step #88: cost=1.0053 lambd=2.6214 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00173 (avg 0.00043)
INFO     |      - augmented_obstacle_constraint(40): 0.02288 (avg 0.00057)
INFO     |      - smoothness_cost(18): 1.00359 (avg 0.02788)
INFO     |      accepted=True ATb_norm=6.84e-02 cost_prev=1.0282 cost_new=1.0255
INFO     |  step #89: cost=1.0026 lambd=1.3107 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00166 (avg 0.00042)
INFO     |      - augmented_obstacle_constraint(40): 0.02297 (avg 0.00057)
INFO     |      - smoothness_cost(18): 1.00092 (avg 0.02780)
INFO     |      accepted=True ATb_norm=5.94e-02 cost_prev=1.0255 cost_new=1.0212
INFO     |  step #90: cost=0.9973 lambd=0.6554 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00159 (avg 0.00040)
INFO     |      - augmented_obstacle_constraint(40): 0.02389 (avg 0.00060)
INFO     |      - smoothness_cost(18): 0.99572 (avg 0.02766)
INFO     |      accepted=True ATb_norm=7.04e-02 cost_prev=1.0212 cost_new=1.0151
INFO     |  step #91: cost=0.9901 lambd=0.3277 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00154 (avg 0.00039)
INFO     |      - augmented_obstacle_constraint(40): 0.02497 (avg 0.00062)
INFO     |      - smoothness_cost(18): 0.98859 (avg 0.02746)
INFO     |      accepted=True ATb_norm=6.03e-02 cost_prev=1.0151 cost_new=1.0066
INFO     |  step #92: cost=0.9777 lambd=0.1638 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.02891 (avg 0.00072)
INFO     |      - smoothness_cost(18): 0.97620 (avg 0.02712)
INFO     |      accepted=True ATb_norm=8.89e-02 cost_prev=1.0066 cost_new=0.9964
INFO     |  step #93: cost=0.9591 lambd=0.0819 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00153 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.03738 (avg 0.00093)
INFO     |      - smoothness_cost(18): 0.95755 (avg 0.02660)
INFO     |      accepted=True ATb_norm=1.43e-01 cost_prev=0.9964 cost_new=0.9867
INFO     |  step #94: cost=0.9352 lambd=0.0410 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00152 (avg 0.00038)
INFO     |      - augmented_obstacle_constraint(40): 0.05151 (avg 0.00129)
INFO     |      - smoothness_cost(18): 0.93370 (avg 0.02594)
INFO     |      accepted=True ATb_norm=1.93e-01 cost_prev=0.9867 cost_new=0.9828
INFO     |  step #95: cost=0.9235 lambd=0.0205 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00150 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.05931 (avg 0.00148)
INFO     |      - smoothness_cost(18): 0.92202 (avg 0.02561)
INFO     |      accepted=True ATb_norm=6.50e-02 cost_prev=0.9828 cost_new=0.9815
INFO     |  step #96: cost=0.9162 lambd=0.0102 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00150 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.06536 (avg 0.00163)
INFO     |      - smoothness_cost(18): 0.91469 (avg 0.02541)
INFO     |      accepted=True ATb_norm=2.81e-02 cost_prev=0.9815 cost_new=0.9811
INFO     |  step #97: cost=0.9121 lambd=0.0051 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00150 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.06906 (avg 0.00173)
INFO     |      - smoothness_cost(18): 0.91059 (avg 0.02529)
INFO     |      accepted=True ATb_norm=1.16e-02 cost_prev=0.9811 cost_new=0.9810
INFO     |  step #98: cost=0.9099 lambd=0.0026 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00150 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.07110 (avg 0.00178)
INFO     |      - smoothness_cost(18): 0.90843 (avg 0.02523)
INFO     |      accepted=True ATb_norm=5.14e-03 cost_prev=0.9810 cost_new=0.9810
INFO     |  AL update: snorm=6.0159e-02, csupn=6.0159e-02, max_rho=1.7129e+02
INFO     |  step #99: cost=0.9088 lambd=0.0013 inexact_tol=1.5e-04
INFO     |      - anchor_cost(2): 0.00149 (avg 0.00037)
INFO     |      - augmented_obstacle_constraint(40): 0.43579 (avg 0.01089)
INFO     |      - smoothness_cost(18): 0.90728 (avg 0.02520)
INFO     |      accepted=True ATb_norm=3.75e+00 cost_prev=1.3446 cost_new=1.1252
INFO     | Terminated @ iteration #100: cost=1.0642 criteria=[0 0 0], term_deltas=1.7e-01,2.7e+00,3.7e-03

Visualization#

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML


def get_trajectory_traces(
    vals: jaxls.VarValues, name_prefix: str = ""
) -> list[go.Scatter]:
    """Build the Plotly traces that draw one trajectory.

    The trajectory is read from ``vals`` for the module-level ``waypoint_vars``.
    Three traces are produced, in this order: the connected path, a marker on
    the first waypoint, and a marker on the last waypoint.

    Args:
        vals: Variable values holding a position for every waypoint.
        name_prefix: String prepended to each trace's legend name.

    Returns:
        ``[path, start_marker, goal_marker]`` as Plotly Scatter traces.
    """
    positions = vals[waypoint_vars]  # one row per waypoint: (n_waypoints, 2)
    # Plotly wants plain Python floats, not JAX scalars.
    xs = list(map(float, positions[:, 0]))
    ys = list(map(float, positions[:, 1]))

    path_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="lines+markers",
        line=dict(color="steelblue", width=2),
        marker=dict(size=6, color="steelblue"),
        name=f"{name_prefix}Path",
        hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>",
    )
    start_trace = go.Scatter(
        x=[xs[0]],
        y=[ys[0]],
        mode="markers",
        marker=dict(size=14, color="green", symbol="circle"),
        name=f"{name_prefix}Start",
        hovertemplate="Start (%{x:.2f}, %{y:.2f})<extra></extra>",
    )
    goal_trace = go.Scatter(
        x=[xs[-1]],
        y=[ys[-1]],
        mode="markers",
        marker=dict(size=14, color="orange", symbol="square"),
        name=f"{name_prefix}Goal",
        hovertemplate="Goal (%{x:.2f}, %{y:.2f})<extra></extra>",
    )
    return [path_trace, start_trace, goal_trace]


def get_obstacle_shapes(
    circles: "list[tuple[float, float, float]] | None" = None,
    *,
    xref: str = "x",
    yref: str = "y",
) -> list[dict]:
    """Get Plotly shapes for circular obstacles.

    Each obstacle ``(center_x, center_y, radius)`` becomes a Plotly ``"circle"``
    shape. Its bounding box is ``[cx - r, cx + r] x [cy - r, cy + r]``.

    Args:
        circles: Obstacles to draw as ``(center_x, center_y, radius)`` tuples.
            Defaults to the module-level ``obstacles`` list.
        xref: Plotly x-axis the shapes refer to (e.g. ``"x2"`` for the second
            subplot).
        yref: Plotly y-axis the shapes refer to (e.g. ``"y2"``).

    Returns:
        One Plotly shape dictionary per obstacle, in input order. An empty
        ``circles`` yields an empty list.
    """
    if circles is None:
        circles = obstacles
    return [
        dict(
            type="circle",
            xref=xref,
            yref=yref,
            # Plotly circles are specified by their axis-aligned bounding box.
            x0=cx - r,
            y0=cy - r,
            x1=cx + r,
            y1=cy + r,
            fillcolor="rgba(255, 99, 71, 0.3)",
            line=dict(color="tomato", width=2),
        )
        for cx, cy, r in circles
    ]


fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Initial (straight line)", "Optimized (avoids obstacles)"),
)

# Left column shows the initial guess, right column the optimized result.
for col, panel_vals in enumerate((initial_vals, solution), start=1):
    for trace in get_trajectory_traces(panel_vals):
        fig.add_trace(trace, row=1, col=col)

# Draw every obstacle in both subplots. Subplot 1 uses axes "x"/"y" and
# subplot 2 uses "x2"/"y2", so each shape's axis references are rewritten.
obstacle_shapes = get_obstacle_shapes()
for shape in obstacle_shapes:
    fig.add_shape(dict(shape, xref="x", yref="y"), row=1, col=1)
    fig.add_shape(dict(shape, xref="x2", yref="y2"), row=1, col=2)

# Identical, aspect-locked axis ranges so the two panels are directly comparable.
for col, axis_suffix in ((1, ""), (2, "2")):
    fig.update_xaxes(title_text="x", range=[-1, 11], row=1, col=col)
    fig.update_yaxes(
        title_text="y",
        range=[-4, 4],
        scaleanchor=f"x{axis_suffix}",
        scaleratio=1,
        row=1,
        col=col,
    )

fig.update_layout(height=400, showlegend=False, margin=dict(t=40, b=40, l=40, r=40))
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

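The optimizer found a smooth trajectory that avoids both obstacles while connecting start to goal. The inequality constraints, one per waypoint–obstacle pair, keep each waypoint outside the obstacle regions. Only the waypoints themselves are constrained; the straight segments between consecutive waypoints are not checked for collisions.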

For solver configuration options, see jaxls.TrustRegionConfig and jaxls.AugmentedLagrangianConfig.