# Obstacle avoidance

In this notebook, we solve an obstacle avoidance problem: finding a smooth path from
start to goal while avoiding circular obstacles.

Features used:
- {class}`~jaxls.Var` for 2D waypoint positions
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` for smoothness and anchor costs
- Inequality constraints (`constraint_leq_zero`) for obstacle avoidance
- Augmented Lagrangian solver for constrained optimization

In [None]:
import sys
from loguru import logger

# Remove all existing loguru handlers, then log to stdout with a compact
# "LEVEL | message" format (the format seen in the solver output further down).
logger.remove()
# The trailing semicolon suppresses IPython's display of the value returned by
# `logger.add`, keeping the cell output clean.
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [2]:
import jax
import jax.numpy as jnp
import jaxls

## Problem setup

Plan a trajectory with 20 waypoints from start to goal, avoiding two obstacles:

In [3]:
# Number of waypoints on the path, counting both endpoints.
n_waypoints = 20

# Path endpoints in the plane, as (x, y).
start = jnp.zeros(2)
goal = jnp.array([10.0, 0.0])

# Circular obstacles, one (center_x, center_y, radius) tuple each. Both have
# radius > |center_y|, so both cut through the straight segment y = 0.
obstacles = [
    (3.5, 1.0, 1.5),  # centered above y = 0
    (6.5, -1.0, 1.5),  # centered below y = 0
]

## Variables and costs

Variable instances are [PyTrees](https://docs.jax.dev/en/latest/pytrees.html). We define a 2D waypoint variable type, the smoothness and anchor costs, and the obstacle-avoidance constraint:

In [4]:
class WaypointVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((2,))):
    """Planar (x, y) position of one waypoint; defaults to the origin."""


# A single batched handle addressing every waypoint, ids 0 .. n_waypoints - 1.
waypoint_vars = WaypointVar(id=jnp.arange(n_waypoints))

In [5]:
@jaxls.Cost.factory
def smoothness_cost(
    vals: jaxls.VarValues,
    var_prev: WaypointVar,
    var_curr: WaypointVar,
    var_next: WaypointVar,
) -> jax.Array:
    """Second-difference residual over three consecutive waypoints.

    ``p_prev - 2 * p_curr + p_next`` is a finite-difference estimate of
    acceleration; it is scaled by a fixed smoothness weight of 5.0.
    """
    p_prev, p_curr, p_next = vals[var_prev], vals[var_curr], vals[var_next]
    second_difference = p_prev - 2 * p_curr + p_next
    smoothness_weight = 5.0
    return smoothness_weight * second_difference


@jaxls.Cost.factory
def anchor_cost(
    vals: jaxls.VarValues,
    var: WaypointVar,
    target: jax.Array,
) -> jax.Array:
    """Residual pulling a waypoint toward a fixed ``target`` position.

    Scaled by a fixed weight of 10.0. This is a weighted cost rather than a
    hard constraint: the waypoint is pulled toward the target, not forced onto it.
    """
    offset = vals[var] - target
    anchor_weight = 10.0
    return anchor_weight * offset


@jaxls.Cost.factory(kind="constraint_leq_zero")
def obstacle_constraint(
    vals: jaxls.VarValues,
    var: WaypointVar,
    center: jax.Array,
    radius: float,
) -> jax.Array:
    """Inequality residual keeping a waypoint outside a circular obstacle.

    Returns a length-1 array holding ``radius - dist``. Registered as
    ``constraint_leq_zero``, so it asks for ``radius - dist <= 0``, i.e.
    ``dist >= radius``. The distance is smoothed as
    ``sqrt(|p - center|^2 + 1e-6)`` so its gradient stays finite even when the
    waypoint sits exactly on the center.
    """
    offset = vals[var] - center
    smoothed_dist = jnp.sqrt(jnp.sum(jnp.square(offset)) + 1e-6)
    violation = radius - smoothed_dist
    return jnp.array([violation])

## Problem construction

In [6]:
# Smoothness terms: one (prev, curr, next) index triple per interior waypoint.
smooth_prev = jnp.arange(n_waypoints - 2)  # 0 .. n-3
smooth_curr = smooth_prev + 1  # 1 .. n-2
smooth_next = smooth_prev + 2  # 2 .. n-1

# Obstacle constraints: one term per (obstacle, waypoint) pair, laid out
# obstacle-major, so entry k pairs waypoint k % n_waypoints with obstacle
# k // n_waypoints.
n_obstacles = len(obstacles)
waypoint_ids_for_obs = jnp.arange(n_waypoints * n_obstacles) % n_waypoints
obstacle_centers = jnp.array([obs[:2] for obs in obstacles])
obstacle_radii = jnp.array([obs[2] for obs in obstacles])
# Expand each obstacle's center/radius to line up with the waypoint ids above.
centers_repeated = jnp.repeat(obstacle_centers, n_waypoints, axis=0)
radii_repeated = jnp.repeat(obstacle_radii, n_waypoints)

costs: list[jaxls.Cost] = [
    # Weighted anchors on the two endpoints.
    anchor_cost(WaypointVar(id=0), start),
    anchor_cost(WaypointVar(id=n_waypoints - 1), goal),
]
# A single batched smoothness cost spanning every interior waypoint.
costs.append(
    smoothness_cost(
        WaypointVar(id=smooth_prev),
        WaypointVar(id=smooth_curr),
        WaypointVar(id=smooth_next),
    )
)
# A single batched inequality constraint spanning every (obstacle, waypoint) pair.
costs.append(
    obstacle_constraint(
        WaypointVar(id=waypoint_ids_for_obs),
        centers_repeated,
        radii_repeated,
    )
)

print(f"Created {len(costs)} batched cost objects")

Created 4 batched cost objects


In [None]:
# Initial guess: waypoints evenly spaced along the straight start-goal segment.
# Both obstacles cut through this segment, so the guess starts out infeasible.
t = jnp.linspace(0, 1, n_waypoints)
initial_positions = start + t[:, None] * (goal - start)  # (n_waypoints, 2)

initial_vals = jaxls.VarValues.make([waypoint_vars.with_value(initial_positions)])

# Create the problem from the cost list and the batched waypoint variables.
problem = jaxls.LeastSquaresProblem(costs, [waypoint_vars])

# Visualize the problem structure.
problem.show()

In [None]:
# Analyze the problem. `analyze()` returns a new problem object, which replaces
# `problem` here and is the object `solve()` is called on in the next cell.
problem = problem.analyze()

## Solving

In [8]:
# Solve starting from the straight-line guess. Because the problem contains
# inequality constraints, the log shows the Augmented Lagrangian solver at work.
solution = problem.solve(initial_vals)

[1mINFO    [0m | Augmented Lagrangian: initial snorm=4.8317e-01, csupn=4.8317e-01, max_rho=1.0706e+01, constraint_dim=40
[1mINFO    [0m |  step #0: cost=0.0000 lambd=0.0005 inexact_tol=1.0e-02
[1mINFO    [0m |      - anchor_cost(2): 0.00000 (avg 0.00000)
[1mINFO    [0m |      - smoothness_cost(18): 0.00000 (avg 0.00000)
[1mINFO    [0m |      - augmented_obstacle_constraint(40): 11.46163 (avg 0.28654)
[1mINFO    [0m |      accepted=True ATb_norm=1.19e+01 cost_prev=11.4616 cost_new=0.9302
[1mINFO    [0m |  step #1: cost=0.8331 lambd=0.0003 inexact_tol=1.0e-02
[1mINFO    [0m |      - anchor_cost(2): 0.00119 (avg 0.00030)
[1mINFO    [0m |      - smoothness_cost(18): 0.83189 (avg 0.02311)
[1mINFO    [0m |      - augmented_obstacle_constraint(40): 0.09714 (avg 0.00243)
[1mINFO    [0m |      accepted=True ATb_norm=9.94e-01 cost_prev=0.9302 cost_new=0.8834
[1mINFO    [0m |  step #2: cost=0.7775 lambd=0.0001 inexact_tol=6.3e-03
[1mINFO    [0m |      - anchor_cost(2): 0

## Visualization

In [9]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML


def get_trajectory_traces(
    vals: jaxls.VarValues, name_prefix: str = ""
) -> list[go.Scatter]:
    """Build the Plotly traces that draw one trajectory.

    Waypoint positions are looked up in ``vals`` through the module-level
    ``waypoint_vars`` batch.

    Args:
        vals: Variable values containing the waypoint positions.
        name_prefix: Text prepended to every trace name.

    Returns:
        Three scatter traces, in this order: the path (lines and markers), a
        marker on the first waypoint, and a marker on the last waypoint.
    """
    points = vals[waypoint_vars]  # (n_waypoints, 2)
    xs = [float(v) for v in points[:, 0]]
    ys = [float(v) for v in points[:, 1]]

    path_trace = go.Scatter(
        x=xs,
        y=ys,
        mode="lines+markers",
        line={"color": "steelblue", "width": 2},
        marker={"size": 6, "color": "steelblue"},
        name=f"{name_prefix}Path",
        hovertemplate="(%{x:.2f}, %{y:.2f})<extra></extra>",
    )
    start_trace = go.Scatter(
        x=[xs[0]],
        y=[ys[0]],
        mode="markers",
        marker={"size": 14, "color": "green", "symbol": "circle"},
        name=f"{name_prefix}Start",
        hovertemplate="Start (%{x:.2f}, %{y:.2f})<extra></extra>",
    )
    goal_trace = go.Scatter(
        x=[xs[-1]],
        y=[ys[-1]],
        mode="markers",
        marker={"size": 14, "color": "orange", "symbol": "square"},
        name=f"{name_prefix}Goal",
        hovertemplate="Goal (%{x:.2f}, %{y:.2f})<extra></extra>",
    )
    return [path_trace, start_trace, goal_trace]


def get_obstacle_shapes() -> list[dict]:
    """Get Plotly shapes for obstacles.

    Returns:
        List of Plotly shape dictionaries for circular obstacles
    """
    return [
        dict(
            type="circle",
            xref="x",
            yref="y",
            x0=cx - r,
            y0=cy - r,
            x1=cx + r,
            y1=cy + r,
            fillcolor="rgba(255, 99, 71, 0.3)",
            line=dict(color="tomato", width=2),
        )
        for cx, cy, r in obstacles
    ]


fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Initial (straight line)", "Optimized (avoids obstacles)"),
)

# Left panel shows the initial guess; right panel shows the optimized path.
for col, panel_vals in ((1, initial_vals), (2, solution)):
    for trace in get_trajectory_traces(panel_vals):
        fig.add_trace(trace, row=1, col=col)

# Draw each obstacle in both panels, pointing every copy at its panel's axes.
obstacle_shapes = get_obstacle_shapes()
for shape in obstacle_shapes:
    for col, (x_ref, y_ref) in ((1, ("x", "y")), (2, ("x2", "y2"))):
        fig.add_shape({**shape, "xref": x_ref, "yref": y_ref}, row=1, col=col)

# Same ranges in both panels, with a 1:1 aspect ratio so circles render round.
for col, x_axis in ((1, "x"), (2, "x2")):
    fig.update_xaxes(title_text="x", range=[-1, 11], row=1, col=col)
    fig.update_yaxes(
        title_text="y",
        range=[-4, 4],
        scaleanchor=x_axis,
        scaleratio=1,
        row=1,
        col=col,
    )
fig.update_layout(height=400, showlegend=False, margin=dict(t=40, b=40, l=40, r=40))
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

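The optimizer found a smooth trajectory that avoids both obstacles while connecting start to goal. The inequality constraints keep every waypoint outside the obstacle regions. These constraints are enforced only at the waypoints themselves; the straight segments drawn between consecutive waypoints are not checked. With a coarse discretization, a segment could therefore still clip the edge of an obstacle. Using more waypoints, or adding a small safety margin to the radii, reduces this risk.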

For solver configuration options, see {class}`jaxls.TrustRegionConfig` and {class}`jaxls.AugmentedLagrangianConfig`.