Cart-pole (collocation)

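In this notebook, we solve a cart-pole swing-up problem using direct collocation. It is an alternative to single shooting: single shooting optimizes only the control sequence and recovers states by forward simulation. Collocation instead treats the state at every timestep as a decision variable, optimizes the whole state and control trajectory simultaneously, and enforces the dynamics as equality constraints.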

Features used:

  • Var subclassing with batched IDs for state and control trajectories

  • @jaxls.Cost.factory with batched arguments

  • Equality constraints (constraint_eq_zero): dynamics collocation

  • Inequality constraints (constraint_leq_zero): control bounds

  • Augmented Lagrangian solver for constrained optimization

Hide code cell source

import sys
from loguru import logger

# Replace loguru's handlers with a single stdout sink: padded level name, then the message.
logger.remove()
logger.add(
    sys.stdout,
    format="<level>{level: <8}</level> | {message}",
);
import jax
import jax.numpy as jnp
import jaxls

Cart-pole dynamics

The cart-pole system has state \([x, \theta, \dot{x}, \dot{\theta}]\) where \(x\) is cart position and \(\theta\) is pole angle (0 = hanging down, \(\pi\) = upright). The control input is a horizontal force on the cart.

# --- Physical parameters ----------------------------------------------------
m_cart = 1.0  # mass of the cart [kg]
m_pole = 0.1  # mass of the pole [kg]
length = 0.5  # HALF the pole length [m]
gravity = 9.81  # gravitational acceleration [m/s^2]
force_limit = 10.0  # bound on |force| applied to the cart [N]

# --- Horizon / discretization -----------------------------------------------
n_swing = 50  # intervals allotted to swinging the pole up
n_hold = 10  # intervals spent holding the upright pose
n_steps = n_swing + n_hold  # total intervals (one control per interval)
dt = 0.05  # duration of one interval [s]

total_time = n_steps * dt
print(f"Total trajectory time: {total_time:.2f}s ({n_swing} swing + {n_hold} hold)")
Total trajectory time: 3.00s (50 swing + 10 hold)
@jax.jit
def cart_pole_dynamics(state: jax.Array, force: jax.Array) -> jax.Array:
    """Return the time derivative of the cart-pole state under a cart force.

    Args:
        state: ``[x, theta, x_dot, theta_dot]``. ``theta = 0`` means the pole
            hangs straight down; ``theta = pi`` means it is upright.
        force: Array whose first entry is the horizontal force on the cart.

    Returns:
        ``[x_dot, theta_dot, x_ddot, theta_ddot]``.

    These are the classic cart-pole equations for a uniform pole of half-length
    ``length`` (hence the 4/3 inertia factor). They are rewritten for an angle
    measured from the hanging-down position, i.e. ``theta_up = pi - theta``.
    That substitution flips the sign of the gravity term in the angular
    acceleration, which makes ``theta = 0`` the stable equilibrium.
    """
    _, angle, cart_vel, angle_vel = state  # cart position is not needed
    push = force[0]

    s = jnp.sin(angle)
    c = jnp.cos(angle)

    mass_sum = m_cart + m_pole
    ml = m_pole * length

    # Force/centripetal term shared by both accelerations.
    common = (push + ml * angle_vel**2 * s) / mass_sum
    # Effective angular inertia, coupled to the cart through cos(theta).
    denom = length * (4.0 / 3.0 - m_pole * c**2 / mass_sum)
    angle_acc = (-gravity * s - c * common) / denom
    cart_acc = common - ml * angle_acc * c / mass_sum

    return jnp.array([cart_vel, angle_vel, cart_acc, angle_acc])


@jax.jit
def rk4_step(state: jax.Array, force: jax.Array, dt: float) -> jax.Array:
    """Advance ``state`` by ``dt`` with one classical 4th-order Runge-Kutta step.

    ``force`` is held constant across all four stages (zero-order hold on the
    control over the step).
    """

    def deriv(s: jax.Array) -> jax.Array:
        return cart_pole_dynamics(s, force)

    half_dt = 0.5 * dt
    k1 = deriv(state)
    k2 = deriv(state + half_dt * k1)
    k3 = deriv(state + half_dt * k2)
    k4 = deriv(state + dt * k3)
    weighted_slope = k1 + 2 * k2 + 2 * k3 + k4
    return state + (dt / 6.0) * weighted_slope

Variables

We define a state variable at each of the n_steps + 1 timesteps and a control input for each of the n_steps intervals between them. Batched IDs let us construct each set of variables in a single call:

class StateVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((4,))):
    """Cart-pole state at one timestep, ``[x, theta, x_dot, theta_dot]``; defaults to all zeros."""


class ControlVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((1,))):
    """Control input for one interval: a length-1 ``[force]`` vector; defaults to zero."""


# One batched variable object per kind: states live on timesteps 0..n_steps,
# controls on the n_steps intervals 0..n_steps-1 between them.
state_vars = StateVar(id=jnp.arange(n_steps + 1))
control_vars = ControlVar(id=jnp.arange(n_steps))

print(f"State variables: {n_steps + 1} (shape: {state_vars.id.shape})")
print(f"Control variables: {n_steps} (shape: {control_vars.id.shape})")
State variables: 61 (shape: (61,))
Control variables: 60 (shape: (60,))

Cost functions

We define costs for:

  1. Dynamics constraints: RK4 integration between consecutive states

  2. Boundary costs: Penalize deviation from the initial state at t = 0 and from the target state during the hold phase

  3. Control effort: Minimize force usage

  4. Control bounds: Keep force within limits

@jaxls.Cost.factory(kind="constraint_eq_zero")
def dynamics_constraint(
    vals: jaxls.VarValues,
    state_k: StateVar,
    state_k1: StateVar,
    control_k: ControlVar,
    dt: float,
) -> jax.Array:
    """Collocation defect: next state minus one RK4 step from the current state.

    Registered as an equality constraint, so the solver drives this residual to
    zero, i.e. ``state_k1`` must match ``rk4_step(state_k, control_k, dt)``.
    """
    predicted = rk4_step(vals[state_k], vals[control_k], dt)
    return vals[state_k1] - predicted


@jaxls.Cost.factory
def boundary_cost(
    vals: jaxls.VarValues,
    var: StateVar,
    target: jax.Array,
) -> jax.Array:
    """Soft residual pulling ``var`` toward ``target``, scaled by a fixed weight of 100."""
    deviation = vals[var] - target
    return deviation * 100.0


@jaxls.Cost.factory
def control_cost(
    vals: jaxls.VarValues,
    var: ControlVar,
) -> jax.Array:
    """Control-effort residual: the force itself, with unit weight."""
    effort = vals[var]
    return effort * 1.0


@jaxls.Cost.factory(kind="constraint_leq_zero")
def control_upper_bound(
    vals: jaxls.VarValues,
    var: ControlVar,
    limit: float,
) -> jax.Array:
    """Upper force bound as ``control - limit <= 0`` (i.e. control <= limit)."""
    value = vals[var]
    return value - limit


@jaxls.Cost.factory(kind="constraint_leq_zero")
def control_lower_bound(
    vals: jaxls.VarValues,
    var: ControlVar,
    limit: float,
) -> jax.Array:
    """Lower force bound as ``-limit - control <= 0`` (i.e. control >= -limit)."""
    value = vals[var]
    return -limit - value

Problem construction

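Initial state: pole hanging down (\(\theta = 0\)). Target state: pole upright (\(\theta = \pi\)), cart centered at the origin, zero velocity. The target is penalized at every state of the final hold phase, not only at the last timestep.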

# Boundary conditions, laid out as [x, theta, x_dot, theta_dot].
initial_state = jnp.array([0.0, 0.0, 0.0, 0.0])  # at rest, pole hanging down
target_state = jnp.array([0.0, jnp.pi, 0.0, 0.0])  # at rest at x=0, pole upright

# Interval k links state k to state k + 1 through control k.
state_k_ids = jnp.arange(n_steps)
state_k1_ids = state_k_ids + 1
control_k_ids = jnp.arange(n_steps)

# Hold phase: states n_swing..n_steps (inclusive) are all pulled to the target,
# so the target row is repeated once per held state to batch the cost.
hold_state_ids = jnp.arange(n_swing, n_steps + 1)
n_hold_states = hold_state_ids.shape[0]
hold_targets = jnp.tile(target_state, (n_hold_states, 1))

# Each factory call below is batched over a whole set of variable IDs.
costs: list[jaxls.Cost] = []
# Collocation defects for every interval (equality constraints).
costs.append(
    dynamics_constraint(
        StateVar(id=state_k_ids),
        StateVar(id=state_k1_ids),
        ControlVar(id=control_k_ids),
        dt,
    )
)
# Soft boundary terms: the first state, then every hold-phase state.
costs.append(boundary_cost(StateVar(id=0), initial_state))
costs.append(boundary_cost(StateVar(id=hold_state_ids), hold_targets))
# Control effort, then both sides of the force bound (inequality constraints).
costs.extend(
    [
        control_cost(control_vars),
        control_upper_bound(control_vars, force_limit),
        control_lower_bound(control_vars, force_limit),
    ]
)

print(f"Created {len(costs)} batched cost objects")
Created 6 batched cost objects
# Warm start: states move linearly from initial_state (timestep 0) to
# target_state (timestep n_steps); every control starts at zero force.
t_interp = jnp.linspace(0, 1, n_steps + 1)[:, jnp.newaxis]  # shape (n_steps + 1, 1)
initial_states = initial_state + t_interp * (target_state - initial_state)
initial_controls = jnp.zeros((n_steps, 1))

initial_vals = jaxls.VarValues.make(
    [state_vars.with_value(initial_states), control_vars.with_value(initial_controls)]
)

# Assemble the least-squares problem over both batched variable sets.
problem = jaxls.LeastSquaresProblem(costs, [state_vars, control_vars])

# Visualize the problem structure.
problem.show()

# Analyze the problem; this groups costs and constraints for vectorized evaluation.
problem = problem.analyze()
INFO     | Building optimization problem with 252 terms and 121 variables: 72 costs, 60 eq_zero, 120 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 60 constraints (constraint_leq_zero), 1 variables each: augmented_control_upper_bound
INFO     | Vectorizing constraint group with 60 constraints (constraint_eq_zero), 3 variables each: augmented_dynamics_constraint
INFO     | Vectorizing group with 12 costs, 1 variables each: boundary_cost
INFO     | Vectorizing group with 60 costs, 1 variables each: control_cost
INFO     | Vectorizing constraint group with 60 constraints (constraint_leq_zero), 1 variables each: augmented_control_lower_bound

Solving

# Solve starting from the interpolated warm start. Because the problem contains
# equality and inequality constraints, the solve runs an augmented Lagrangian
# loop (as reported in the solver log).
solution = problem.solve(initial_vals)
INFO     | Augmented Lagrangian: initial snorm=7.3573e-01, csupn=7.3573e-01, max_rho=1.0572e+05, constraint_dim=360
INFO     |  step #0: cost=10554.9854 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #1: cost=10554.9854 lambd=0.0010 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #2: cost=10554.9854 lambd=0.0020 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #3: cost=10554.9854 lambd=0.0040 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #4: cost=10554.9854 lambd=0.0080 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #5: cost=10554.9854 lambd=0.0160 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #6: cost=10554.9854 lambd=0.0320 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #7: cost=10554.9854 lambd=0.0640 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #8: cost=10554.9854 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #9: cost=10554.9854 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #10: cost=10554.9854 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #11: cost=10554.9854 lambd=1.0240 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #12: cost=10554.9854 lambd=2.0480 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #13: cost=10554.9854 lambd=4.0960 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #14: cost=10554.9854 lambd=8.1920 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #15: cost=10554.9854 lambd=16.3840 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1805924.12500 (avg 7524.68408)
INFO     |      - boundary_cost(12): 10554.98535 (avg 219.89554)
INFO     |      - control_cost(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.64e+05 cost_prev=1816479.1250 cost_new=1205266.7500
INFO     |  step #16: cost=3563.0605 lambd=8.1920 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 1201703.62500 (avg 5007.09863)
INFO     |      - boundary_cost(12): 2345.29688 (avg 48.86035)
INFO     |      - control_cost(60): 1217.76355 (avg 20.29606)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.58e+05 cost_prev=1205266.7500 cost_new=37901.0586
INFO     |  step #17: cost=1887.9652 lambd=4.0960 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 36013.09766 (avg 150.05458)
INFO     |      - boundary_cost(12): 705.00037 (avg 14.68751)
INFO     |      - control_cost(60): 1182.96484 (avg 19.71608)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=4.25e+04 cost_prev=37901.0586 cost_new=2399.1091
INFO     |  step #18: cost=1454.8644 lambd=2.0480 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 944.24457 (avg 3.93435)
INFO     |      - boundary_cost(12): 239.56528 (avg 4.99094)
INFO     |      - control_cost(60): 1215.29907 (avg 20.25499)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #19: cost=1454.8644 lambd=4.0960 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 944.24457 (avg 3.93435)
INFO     |      - boundary_cost(12): 239.56528 (avg 4.99094)
INFO     |      - control_cost(60): 1215.29907 (avg 20.25499)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #20: cost=1454.8644 lambd=8.1920 inexact_tol=1.0e-02
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 944.24457 (avg 3.93435)
INFO     |      - boundary_cost(12): 239.56528 (avg 4.99094)
INFO     |      - control_cost(60): 1215.29907 (avg 20.25499)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.15e+03 cost_prev=2399.1091 cost_new=1538.6353
INFO     |  step #21: cost=1296.8329 lambd=4.0960 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 241.80228 (avg 1.00751)
INFO     |      - boundary_cost(12): 85.79935 (avg 1.78749)
INFO     |      - control_cost(60): 1211.03357 (avg 20.18389)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #22: cost=1296.8329 lambd=8.1920 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 241.80228 (avg 1.00751)
INFO     |      - boundary_cost(12): 85.79935 (avg 1.78749)
INFO     |      - control_cost(60): 1211.03357 (avg 20.18389)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #23: cost=1296.8329 lambd=16.3840 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 241.80228 (avg 1.00751)
INFO     |      - boundary_cost(12): 85.79935 (avg 1.78749)
INFO     |      - control_cost(60): 1211.03357 (avg 20.18389)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #24: cost=1296.8329 lambd=32.7680 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 241.80228 (avg 1.00751)
INFO     |      - boundary_cost(12): 85.79935 (avg 1.78749)
INFO     |      - control_cost(60): 1211.03357 (avg 20.18389)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.76e+03 cost_prev=1538.6353 cost_new=1421.9773
INFO     |  step #25: cost=1283.0844 lambd=16.3840 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 138.89305 (avg 0.57872)
INFO     |      - boundary_cost(12): 71.25256 (avg 1.48443)
INFO     |      - control_cost(60): 1211.83179 (avg 20.19720)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #26: cost=1283.0844 lambd=32.7680 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 138.89305 (avg 0.57872)
INFO     |      - boundary_cost(12): 71.25256 (avg 1.48443)
INFO     |      - control_cost(60): 1211.83179 (avg 20.19720)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |  step #27: cost=1283.0844 lambd=65.5360 inexact_tol=4.9e-03
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 138.89305 (avg 0.57872)
INFO     |      - boundary_cost(12): 71.25256 (avg 1.48443)
INFO     |      - control_cost(60): 1211.83179 (avg 20.19720)
INFO     |      - augmented_control_lower_bound(60): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.25e+01 cost_prev=1421.9773 cost_new=1407.2819
INFO     |  step #28: cost=1277.0970 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 127.42957 (avg 0.53096)
INFO     |      - boundary_cost(12): 65.77699 (avg 1.37035)
INFO     |      - control_cost(60): 1211.32007 (avg 20.18867)
INFO     |      - augmented_control_lower_bound(60): 2.75521 (avg 0.04592)
INFO     |      accepted=True ATb_norm=5.64e+02 cost_prev=1407.2819 cost_new=1376.2341
INFO     |  step #29: cost=1264.5354 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 111.69856 (avg 0.46541)
INFO     |      - boundary_cost(12): 57.64870 (avg 1.20101)
INFO     |      - control_cost(60): 1206.88672 (avg 20.11478)
INFO     |      - augmented_control_lower_bound(60): 0.00011 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.06e+01 cost_prev=1376.2341 cost_new=1334.5265
INFO     |  step #30: cost=1243.3263 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 91.20001 (avg 0.38000)
INFO     |      - boundary_cost(12): 46.94744 (avg 0.97807)
INFO     |      - control_cost(60): 1196.37891 (avg 19.93965)
INFO     |      - augmented_control_lower_bound(60): 0.00008 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.67e+01 cost_prev=1334.5265 cost_new=1282.3444
INFO     |  step #31: cost=1211.0182 lambd=4.0960 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 71.32619 (avg 0.29719)
INFO     |      - boundary_cost(12): 36.44006 (avg 0.75917)
INFO     |      - control_cost(60): 1174.57812 (avg 19.57630)
INFO     |      - augmented_control_lower_bound(60): 0.00005 (avg 0.00000)
INFO     |  step #32: cost=1211.0182 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 71.32619 (avg 0.29719)
INFO     |      - boundary_cost(12): 36.44006 (avg 0.75917)
INFO     |      - control_cost(60): 1174.57812 (avg 19.57630)
INFO     |      - augmented_control_lower_bound(60): 0.00005 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.11e+02 cost_prev=1282.3444 cost_new=1247.3392
INFO     |  step #33: cost=1186.4778 lambd=4.0960 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 60.86132 (avg 0.25359)
INFO     |      - boundary_cost(12): 31.16541 (avg 0.64928)
INFO     |      - control_cost(60): 1155.31238 (avg 19.25521)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #34: cost=1186.4778 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 60.86132 (avg 0.25359)
INFO     |      - boundary_cost(12): 31.16541 (avg 0.64928)
INFO     |      - control_cost(60): 1155.31238 (avg 19.25521)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #35: cost=1186.4778 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 60.86132 (avg 0.25359)
INFO     |      - boundary_cost(12): 31.16541 (avg 0.64928)
INFO     |      - control_cost(60): 1155.31238 (avg 19.25521)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.84e+01 cost_prev=1247.3392 cost_new=1233.3654
INFO     |  step #36: cost=1176.0522 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.31307 (avg 0.23880)
INFO     |      - boundary_cost(12): 29.21135 (avg 0.60857)
INFO     |      - control_cost(60): 1146.84094 (avg 19.11402)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #37: cost=1176.0522 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.31307 (avg 0.23880)
INFO     |      - boundary_cost(12): 29.21135 (avg 0.60857)
INFO     |      - control_cost(60): 1146.84094 (avg 19.11402)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #38: cost=1176.0522 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.31307 (avg 0.23880)
INFO     |      - boundary_cost(12): 29.21135 (avg 0.60857)
INFO     |      - control_cost(60): 1146.84094 (avg 19.11402)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #39: cost=1176.0522 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.31307 (avg 0.23880)
INFO     |      - boundary_cost(12): 29.21135 (avg 0.60857)
INFO     |      - control_cost(60): 1146.84094 (avg 19.11402)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #40: cost=1176.0522 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.31307 (avg 0.23880)
INFO     |      - boundary_cost(12): 29.21135 (avg 0.60857)
INFO     |      - control_cost(60): 1146.84094 (avg 19.11402)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |  step #41: cost=1176.0522 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.31307 (avg 0.23880)
INFO     |      - boundary_cost(12): 29.21135 (avg 0.60857)
INFO     |      - control_cost(60): 1146.84094 (avg 19.11402)
INFO     |      - augmented_control_lower_bound(60): 0.00003 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.75e+01 cost_prev=1233.3654 cost_new=1233.3387
INFO     |  step #42: cost=1175.4315 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.11014 (avg 0.23796)
INFO     |      - boundary_cost(12): 29.08092 (avg 0.60585)
INFO     |      - control_cost(60): 1146.35059 (avg 19.10584)
INFO     |      - augmented_control_lower_bound(60): 0.79700 (avg 0.01328)
INFO     |      accepted=True ATb_norm=3.02e+02 cost_prev=1233.3387 cost_new=1231.0483
INFO     |  step #43: cost=1173.9244 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 57.12369 (avg 0.23802)
INFO     |      - boundary_cost(12): 28.93181 (avg 0.60275)
INFO     |      - control_cost(60): 1144.99268 (avg 19.08321)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |      accepted=True ATb_norm=9.92e+00 cost_prev=1231.0483 cost_new=1228.1472
INFO     |  step #44: cost=1171.2439 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 56.90316 (avg 0.23710)
INFO     |      - boundary_cost(12): 28.71295 (avg 0.59819)
INFO     |      - control_cost(60): 1142.53101 (avg 19.04218)
INFO     |      - augmented_control_lower_bound(60): 0.00009 (avg 0.00000)
INFO     |      accepted=True ATb_norm=9.74e+00 cost_prev=1228.1472 cost_new=1222.7347
INFO     |  step #45: cost=1166.3917 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 56.34296 (avg 0.23476)
INFO     |      - boundary_cost(12): 28.35230 (avg 0.59067)
INFO     |      - control_cost(60): 1138.03943 (avg 18.96733)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |      accepted=True ATb_norm=9.71e+00 cost_prev=1222.7347 cost_new=1213.1960
INFO     |  step #46: cost=1157.9268 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 55.26915 (avg 0.23029)
INFO     |      - boundary_cost(12): 27.75149 (avg 0.57816)
INFO     |      - control_cost(60): 1130.17529 (avg 18.83626)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.25e+01 cost_prev=1213.1960 cost_new=1197.9978
INFO     |  step #47: cost=1144.2170 lambd=4.0960 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.78064 (avg 0.22409)
INFO     |      - boundary_cost(12): 26.87141 (avg 0.55982)
INFO     |      - control_cost(60): 1117.34558 (avg 18.62243)
INFO     |      - augmented_control_lower_bound(60): 0.00011 (avg 0.00000)
INFO     |  step #48: cost=1144.2170 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.78064 (avg 0.22409)
INFO     |      - boundary_cost(12): 26.87141 (avg 0.55982)
INFO     |      - control_cost(60): 1117.34558 (avg 18.62243)
INFO     |      - augmented_control_lower_bound(60): 0.00011 (avg 0.00000)
INFO     |  step #49: cost=1144.2170 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.78064 (avg 0.22409)
INFO     |      - boundary_cost(12): 26.87141 (avg 0.55982)
INFO     |      - control_cost(60): 1117.34558 (avg 18.62243)
INFO     |      - augmented_control_lower_bound(60): 0.00011 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.84e+01 cost_prev=1197.9978 cost_new=1191.4976
INFO     |  step #50: cost=1138.3524 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.14507 (avg 0.22144)
INFO     |      - boundary_cost(12): 26.34883 (avg 0.54893)
INFO     |      - control_cost(60): 1112.00354 (avg 18.53339)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #51: cost=1138.3524 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.14507 (avg 0.22144)
INFO     |      - boundary_cost(12): 26.34883 (avg 0.54893)
INFO     |      - control_cost(60): 1112.00354 (avg 18.53339)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #52: cost=1138.3524 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.14507 (avg 0.22144)
INFO     |      - boundary_cost(12): 26.34883 (avg 0.54893)
INFO     |      - control_cost(60): 1112.00354 (avg 18.53339)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |      accepted=True ATb_norm=9.36e+00 cost_prev=1191.4976 cost_new=1188.5089
INFO     |  step #53: cost=1135.6526 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.85628 (avg 0.22023)
INFO     |      - boundary_cost(12): 26.09446 (avg 0.54363)
INFO     |      - control_cost(60): 1109.55811 (avg 18.49264)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #54: cost=1135.6526 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.85628 (avg 0.22023)
INFO     |      - boundary_cost(12): 26.09446 (avg 0.54363)
INFO     |      - control_cost(60): 1109.55811 (avg 18.49264)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #55: cost=1135.6526 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.85628 (avg 0.22023)
INFO     |      - boundary_cost(12): 26.09446 (avg 0.54363)
INFO     |      - control_cost(60): 1109.55811 (avg 18.49264)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |      accepted=True ATb_norm=7.11e+00 cost_prev=1188.5089 cost_new=1187.0741
INFO     |  step #56: cost=1134.3571 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.71685 (avg 0.21965)
INFO     |      - boundary_cost(12): 25.97153 (avg 0.54107)
INFO     |      - control_cost(60): 1108.38550 (avg 18.47309)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #57: cost=1134.3571 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.71685 (avg 0.21965)
INFO     |      - boundary_cost(12): 25.97153 (avg 0.54107)
INFO     |      - control_cost(60): 1108.38550 (avg 18.47309)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #58: cost=1134.3571 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.71685 (avg 0.21965)
INFO     |      - boundary_cost(12): 25.97153 (avg 0.54107)
INFO     |      - control_cost(60): 1108.38550 (avg 18.47309)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |  step #59: cost=1134.3571 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.71685 (avg 0.21965)
INFO     |      - boundary_cost(12): 25.97153 (avg 0.54107)
INFO     |      - control_cost(60): 1108.38550 (avg 18.47309)
INFO     |      - augmented_control_lower_bound(60): 0.00010 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.85e+00 cost_prev=1187.0741 cost_new=1186.7456
INFO     |  step #60: cost=1134.0378 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.68249 (avg 0.21951)
INFO     |      - boundary_cost(12): 25.94023 (avg 0.54042)
INFO     |      - control_cost(60): 1108.09766 (avg 18.46830)
INFO     |      - augmented_control_lower_bound(60): 0.02531 (avg 0.00042)
INFO     |      accepted=True ATb_norm=5.27e+01 cost_prev=1186.7456 cost_new=1186.0742
INFO     |  step #61: cost=1133.2061 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.86800 (avg 0.22028)
INFO     |      - boundary_cost(12): 25.93231 (avg 0.54026)
INFO     |      - control_cost(60): 1107.27380 (avg 18.45457)
INFO     |      - augmented_control_lower_bound(60): 0.00014 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.51e+00 cost_prev=1186.0742 cost_new=1184.8181
INFO     |  step #62: cost=1131.7074 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.11071 (avg 0.22129)
INFO     |      - boundary_cost(12): 25.95303 (avg 0.54069)
INFO     |      - control_cost(60): 1105.75439 (avg 18.42924)
INFO     |      - augmented_control_lower_bound(60): 0.00015 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.40e+00 cost_prev=1184.8181 cost_new=1182.4479
INFO     |  step #63: cost=1129.0564 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.39143 (avg 0.22246)
INFO     |      - boundary_cost(12): 26.00887 (avg 0.54185)
INFO     |      - control_cost(60): 1103.04749 (avg 18.38413)
INFO     |      - augmented_control_lower_bound(60): 0.00016 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.30e+00 cost_prev=1182.4479 cost_new=1178.1984
INFO     |  step #64: cost=1124.5724 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.62589 (avg 0.22344)
INFO     |      - boundary_cost(12): 26.06059 (avg 0.54293)
INFO     |      - control_cost(60): 1098.51184 (avg 18.30853)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.95e+00 cost_prev=1178.1984 cost_new=1171.2988
INFO     |  step #65: cost=1117.6056 lambd=4.0960 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.69297 (avg 0.22372)
INFO     |      - boundary_cost(12): 25.96902 (avg 0.54102)
INFO     |      - control_cost(60): 1091.63660 (avg 18.19394)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.29e+01 cost_prev=1171.2988 cost_new=1161.9695
INFO     |  step #66: cost=1108.3529 lambd=2.0480 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.61647 (avg 0.22340)
INFO     |      - boundary_cost(12): 25.64678 (avg 0.53431)
INFO     |      - control_cost(60): 1082.70618 (avg 18.04510)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |      accepted=True ATb_norm=3.04e+01 cost_prev=1161.9695 cost_new=1152.7798
INFO     |  step #67: cost=1099.3597 lambd=1.0240 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.41990 (avg 0.22258)
INFO     |      - boundary_cost(12): 25.16574 (avg 0.52429)
INFO     |      - control_cost(60): 1074.19397 (avg 17.90323)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.42e+01 cost_prev=1152.7798 cost_new=1147.2076
INFO     |  step #68: cost=1094.0739 lambd=0.5120 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.13367 (avg 0.22139)
INFO     |      - boundary_cost(12): 24.63691 (avg 0.51327)
INFO     |      - control_cost(60): 1069.43689 (avg 17.82395)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |  step #69: cost=1094.0739 lambd=1.0240 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.13367 (avg 0.22139)
INFO     |      - boundary_cost(12): 24.63691 (avg 0.51327)
INFO     |      - control_cost(60): 1069.43689 (avg 17.82395)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |  step #70: cost=1094.0739 lambd=2.0480 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.13367 (avg 0.22139)
INFO     |      - boundary_cost(12): 24.63691 (avg 0.51327)
INFO     |      - control_cost(60): 1069.43689 (avg 17.82395)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |  step #71: cost=1094.0739 lambd=4.0960 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.13367 (avg 0.22139)
INFO     |      - boundary_cost(12): 24.63691 (avg 0.51327)
INFO     |      - control_cost(60): 1069.43689 (avg 17.82395)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |  step #72: cost=1094.0739 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.13367 (avg 0.22139)
INFO     |      - boundary_cost(12): 24.63691 (avg 0.51327)
INFO     |      - control_cost(60): 1069.43689 (avg 17.82395)
INFO     |      - augmented_control_lower_bound(60): 0.00017 (avg 0.00000)
INFO     |      accepted=True ATb_norm=5.57e+01 cost_prev=1147.2076 cost_new=1146.7330
INFO     |  step #73: cost=1093.7294 lambd=4.0960 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.00356 (avg 0.22085)
INFO     |      - boundary_cost(12): 24.33462 (avg 0.50697)
INFO     |      - control_cost(60): 1069.39478 (avg 17.82325)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #74: cost=1093.7294 lambd=8.1920 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.00356 (avg 0.22085)
INFO     |      - boundary_cost(12): 24.33462 (avg 0.50697)
INFO     |      - control_cost(60): 1069.39478 (avg 17.82325)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #75: cost=1093.7294 lambd=16.3840 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.00356 (avg 0.22085)
INFO     |      - boundary_cost(12): 24.33462 (avg 0.50697)
INFO     |      - control_cost(60): 1069.39478 (avg 17.82325)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #76: cost=1093.7294 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.00356 (avg 0.22085)
INFO     |      - boundary_cost(12): 24.33462 (avg 0.50697)
INFO     |      - control_cost(60): 1069.39478 (avg 17.82325)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #77: cost=1093.7294 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 53.00356 (avg 0.22085)
INFO     |      - boundary_cost(12): 24.33462 (avg 0.50697)
INFO     |      - control_cost(60): 1069.39478 (avg 17.82325)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.61e+00 cost_prev=1146.7330 cost_new=1146.6871
INFO     |  step #78: cost=1093.6888 lambd=32.7680 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.99823 (avg 0.22083)
INFO     |      - boundary_cost(12): 24.29304 (avg 0.50611)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #79: cost=1093.6888 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.99823 (avg 0.22083)
INFO     |      - boundary_cost(12): 24.29304 (avg 0.50611)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #80: cost=1093.6888 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.99823 (avg 0.22083)
INFO     |      - boundary_cost(12): 24.29304 (avg 0.50611)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.25e+00 cost_prev=1146.6871 cost_new=1146.6644
INFO     |  step #81: cost=1093.6729 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.99146 (avg 0.22080)
INFO     |      - boundary_cost(12): 24.27709 (avg 0.50577)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #82: cost=1093.6729 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.99146 (avg 0.22080)
INFO     |      - boundary_cost(12): 24.27709 (avg 0.50577)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |  step #83: cost=1093.6729 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 52.99146 (avg 0.22080)
INFO     |      - boundary_cost(12): 24.27709 (avg 0.50577)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00018 (avg 0.00000)
INFO     |      accepted=True ATb_norm=1.24e+00 cost_prev=1146.6644 cost_new=1146.6539
INFO     |  AL update: snorm=5.1765e-03, csupn=5.1765e-03, max_rho=4.2288e+05
INFO     |  step #84: cost=1093.6654 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 211.95322 (avg 0.88314)
INFO     |      - boundary_cost(12): 24.26969 (avg 0.50562)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00112 (avg 0.00002)
INFO     |  step #85: cost=1093.6654 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 211.95322 (avg 0.88314)
INFO     |      - boundary_cost(12): 24.26969 (avg 0.50562)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00112 (avg 0.00002)
INFO     |  step #86: cost=1093.6654 lambd=524.2880 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 211.95322 (avg 0.88314)
INFO     |      - boundary_cost(12): 24.26969 (avg 0.50562)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00112 (avg 0.00002)
INFO     |  step #87: cost=1093.6654 lambd=1048.5760 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 211.95322 (avg 0.88314)
INFO     |      - boundary_cost(12): 24.26969 (avg 0.50562)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00112 (avg 0.00002)
INFO     |  step #88: cost=1093.6654 lambd=2097.1521 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 211.95322 (avg 0.88314)
INFO     |      - boundary_cost(12): 24.26969 (avg 0.50562)
INFO     |      - control_cost(60): 1069.39575 (avg 17.82326)
INFO     |      - augmented_control_lower_bound(60): 0.00112 (avg 0.00002)
INFO     |      accepted=True ATb_norm=4.95e+02 cost_prev=1305.6198 cost_new=1301.2649
INFO     |  step #89: cost=1105.2738 lambd=1048.5760 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 193.44421 (avg 0.80602)
INFO     |      - boundary_cost(12): 35.24968 (avg 0.73437)
INFO     |      - control_cost(60): 1070.02417 (avg 17.83374)
INFO     |      - augmented_control_lower_bound(60): 2.54674 (avg 0.04245)
INFO     |      accepted=True ATb_norm=5.35e+02 cost_prev=1301.2649 cost_new=1295.5409
INFO     |  step #90: cost=1111.5898 lambd=524.2880 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 183.94785 (avg 0.76645)
INFO     |      - boundary_cost(12): 40.65934 (avg 0.84707)
INFO     |      - control_cost(60): 1070.93054 (avg 17.84884)
INFO     |      - augmented_control_lower_bound(60): 0.00318 (avg 0.00005)
INFO     |      accepted=True ATb_norm=3.82e+01 cost_prev=1295.5409 cost_new=1292.2277
INFO     |  step #91: cost=1118.1865 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 174.03940 (avg 0.72516)
INFO     |      - boundary_cost(12): 45.34199 (avg 0.94462)
INFO     |      - control_cost(60): 1072.84448 (avg 17.88074)
INFO     |      - augmented_control_lower_bound(60): 0.00190 (avg 0.00003)
INFO     |      accepted=True ATb_norm=2.74e+01 cost_prev=1292.2277 cost_new=1288.2720
INFO     |  step #92: cost=1125.3481 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 162.92201 (avg 0.67884)
INFO     |      - boundary_cost(12): 48.94568 (avg 1.01970)
INFO     |      - control_cost(60): 1076.40247 (avg 17.94004)
INFO     |      - augmented_control_lower_bound(60): 0.00165 (avg 0.00003)
INFO     |      accepted=True ATb_norm=2.17e+01 cost_prev=1288.2720 cost_new=1282.9974
INFO     |  step #93: cost=1133.6633 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 149.33278 (avg 0.62222)
INFO     |      - boundary_cost(12): 51.02476 (avg 1.06302)
INFO     |      - control_cost(60): 1082.63855 (avg 18.04398)
INFO     |      - augmented_control_lower_bound(60): 0.00138 (avg 0.00002)
INFO     |  step #94: cost=1133.6633 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 149.33278 (avg 0.62222)
INFO     |      - boundary_cost(12): 51.02476 (avg 1.06302)
INFO     |      - control_cost(60): 1082.63855 (avg 18.04398)
INFO     |      - augmented_control_lower_bound(60): 0.00138 (avg 0.00002)
INFO     |      accepted=True ATb_norm=1.78e+01 cost_prev=1282.9974 cost_new=1279.1952
INFO     |  step #95: cost=1139.7399 lambd=65.5360 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 139.45413 (avg 0.58106)
INFO     |      - boundary_cost(12): 51.56386 (avg 1.07425)
INFO     |      - control_cost(60): 1088.17603 (avg 18.13627)
INFO     |      - augmented_control_lower_bound(60): 0.00118 (avg 0.00002)
INFO     |  step #96: cost=1139.7399 lambd=131.0720 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 139.45413 (avg 0.58106)
INFO     |      - boundary_cost(12): 51.56386 (avg 1.07425)
INFO     |      - control_cost(60): 1088.17603 (avg 18.13627)
INFO     |      - augmented_control_lower_bound(60): 0.00118 (avg 0.00002)
INFO     |  step #97: cost=1139.7399 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 139.45413 (avg 0.58106)
INFO     |      - boundary_cost(12): 51.56386 (avg 1.07425)
INFO     |      - control_cost(60): 1088.17603 (avg 18.13627)
INFO     |      - augmented_control_lower_bound(60): 0.00118 (avg 0.00002)
INFO     |  step #98: cost=1139.7399 lambd=524.2880 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 139.45413 (avg 0.58106)
INFO     |      - boundary_cost(12): 51.56386 (avg 1.07425)
INFO     |      - control_cost(60): 1088.17603 (avg 18.13627)
INFO     |      - augmented_control_lower_bound(60): 0.00118 (avg 0.00002)
INFO     |      accepted=True ATb_norm=1.52e+01 cost_prev=1279.1952 cost_new=1278.4221
INFO     |  step #99: cost=1141.1459 lambd=262.1440 inexact_tol=3.0e-04
INFO     |      - augmented_control_upper_bound(60): 0.00000 (avg 0.00000)
INFO     |      - augmented_dynamics_constraint(60): 137.21538 (avg 0.57173)
INFO     |      - boundary_cost(12): 51.62531 (avg 1.07553)
INFO     |      - control_cost(60): 1089.52051 (avg 18.15868)
INFO     |      - augmented_control_lower_bound(60): 0.06096 (avg 0.00102)
INFO     |      accepted=True ATb_norm=7.87e+01 cost_prev=1278.4221 cost_new=1277.1542
INFO     | Terminated @ iteration #100: cost=1143.2786 criteria=[0 0 0], term_deltas=1.9e-03,7.8e+01,1.3e-04

Visualization#

Extract the solution trajectories and create an animated visualization:

# Extract solution trajectories.
# states: (n_steps + 1, 4), rows are [x, theta, x_dot, theta_dot].
# controls: (n_steps, 1), the horizontal force applied on each step.
states = solution[state_vars]
controls = solution[control_vars]

# One time stamp per state sample, and one per control (the start of each step).
times = jnp.linspace(0, total_time, n_steps + 1)
control_times = jnp.linspace(0, total_time - dt, n_steps)

# Report how close the last state gets to upright (theta = pi).
final_state = states[-1]
print(f"Final state: x={float(final_state[0]):.3f}, theta={float(final_state[1]):.3f}")
print(f"Target theta (pi): {jnp.pi:.3f}")
Final state: x=0.001, theta=3.141
Target theta (pi): 3.142

Hide code cell source

import math

import plotly.graph_objects as go
from IPython.display import HTML
from plotly.subplots import make_subplots

# Animation of cart-pole swing-up.
# Geometry constants shared by the frame builder and the animation figure
# (the plot axes are labelled in metres).
cart_y_offset = 1.2  # Raise cart so pole can swing freely above ground.
# Wheel radius: used only to place the wheel centres and the rail. The wheel
# markers themselves are drawn at a fixed pixel size.
wheel_radius = 0.05


def create_cart_pole_frame(
    x: float,
    theta: float,
    *,
    cart_width: float = 0.4,
    cart_height: float = 0.15,
) -> dict:
    """Return plot coordinates for one cart-pole animation frame.

    Args:
        x: Cart position along the rail (m).
        theta: Pole angle (rad); 0 hangs straight down, pi points straight up.
        cart_width: Drawn width of the cart body (keyword-only).
        cart_height: Drawn height of the cart body (keyword-only).

    Returns:
        Dict of plain-float coordinate lists for ``go.Scatter`` traces:
        ``cart_x``/``cart_y`` (closed rectangle, 5 points), ``pole_x``/``pole_y``
        (pivot to tip), ``mass_x``/``mass_y`` (pole tip), and
        ``wheel1_x``/``wheel1_y``, ``wheel2_x``/``wheel2_y`` (wheel centres).
    """
    x = float(x)
    # `length` is the pole *half*-length, so the drawn pole is twice as long.
    pole_len = 2 * length
    # Plain `math` trig on Python floats: no JAX op dispatch per frame, and the
    # results are already floats that plotly can serialize directly.
    tip_x = x + pole_len * math.sin(theta)
    tip_y = cart_y_offset - pole_len * math.cos(theta)

    left, right = x - cart_width / 2, x + cart_width / 2
    bottom = cart_y_offset - cart_height / 2
    top = cart_y_offset + cart_height / 2
    # Wheel centres sit one wheel radius below the cart body.
    wheel_y = bottom - wheel_radius

    return dict(
        cart_x=[left, right, right, left, left],
        cart_y=[bottom, bottom, top, top, bottom],
        pole_x=[x, tip_x],
        pole_y=[cart_y_offset, tip_y],
        mass_x=[tip_x],
        mass_y=[tip_y],
        wheel1_x=[x - cart_width / 3],
        wheel1_y=[wheel_y],
        wheel2_x=[x + cart_width / 3],
        wheel2_y=[wheel_y],
    )


# Create animation frames (use every 2nd timestep for performance).
frame_step = 2
frame_indices = list(range(0, n_steps + 1, frame_step))
if frame_indices[-1] != n_steps:
    # Always finish on the final optimized state, even when n_steps is not a
    # multiple of frame_step (range() would otherwise skip it).
    frame_indices.append(n_steps)


def _frame_traces(frame_data: dict) -> list:
    """Build the five per-frame traces: cart body, pole, pole-tip mass, two wheels.

    Plotly applies frame traces to the figure's traces by position, so this
    order must match the order of the first five traces in the figure.
    """
    wheel_marker = dict(size=12, color="black", symbol="circle")
    return [
        go.Scatter(
            x=frame_data["cart_x"],
            y=frame_data["cart_y"],
            fill="toself",
            fillcolor="steelblue",
            line=dict(color="darkblue", width=2),
        ),
        go.Scatter(
            x=frame_data["pole_x"],
            y=frame_data["pole_y"],
            mode="lines",
            line=dict(color="sienna", width=8),
        ),
        go.Scatter(
            x=frame_data["mass_x"],
            y=frame_data["mass_y"],
            mode="markers",
            marker=dict(size=18, color="sienna", line=dict(color="black", width=1)),
        ),
        go.Scatter(
            x=frame_data["wheel1_x"],
            y=frame_data["wheel1_y"],
            mode="markers",
            marker=wheel_marker,
        ),
        go.Scatter(
            x=frame_data["wheel2_x"],
            y=frame_data["wheel2_y"],
            mode="markers",
            marker=wheel_marker,
        ),
    ]


# Frame names are the timestep indices; the slider refers to frames by name.
frames = [
    go.Frame(
        data=_frame_traces(
            create_cart_pole_frame(float(states[i, 0]), float(states[i, 1]))
        ),
        name=str(i),
    )
    for i in frame_indices
]

# Initial frame.
init_frame = create_cart_pole_frame(float(states[0, 0]), float(states[0, 1]))
# Rail height: one wheel radius below the wheel centres. Deriving it from the
# frame geometry (rather than repeating the cart height as a literal) keeps the
# rail aligned with the wheels if the cart dimensions change.
rail_y = init_frame["wheel1_y"][0] - wheel_radius
# Horizontal extent shared by the rail and the x-axis.
x_limits = [-2.5, 2.5]

fig_anim = go.Figure(
    data=[
        go.Scatter(
            x=init_frame["cart_x"],
            y=init_frame["cart_y"],
            fill="toself",
            fillcolor="steelblue",
            line=dict(color="darkblue", width=2),
            name="Cart",
        ),
        go.Scatter(
            x=init_frame["pole_x"],
            y=init_frame["pole_y"],
            mode="lines",
            line=dict(color="sienna", width=8),
            name="Pole",
        ),
        go.Scatter(
            x=init_frame["mass_x"],
            y=init_frame["mass_y"],
            mode="markers",
            marker=dict(size=18, color="sienna", line=dict(color="black", width=1)),
            showlegend=False,
        ),
        go.Scatter(
            x=init_frame["wheel1_x"],
            y=init_frame["wheel1_y"],
            mode="markers",
            marker=dict(size=12, color="black", symbol="circle"),
            showlegend=False,
        ),
        go.Scatter(
            x=init_frame["wheel2_x"],
            y=init_frame["wheel2_y"],
            mode="markers",
            marker=dict(size=12, color="black", symbol="circle"),
            showlegend=False,
        ),
        # Static rail; listed last so the per-frame traces (which update the
        # first five traces by position) never touch it.
        go.Scatter(
            x=list(x_limits),
            y=[rail_y, rail_y],
            mode="lines",
            line=dict(color="gray", width=6),
            name="Rail",
        ),
    ],
    frames=frames,
    layout=go.Layout(
        title="Cart-Pole (Collocation)",
        xaxis=dict(
            range=list(x_limits), title="x (m)", constrain="domain", showgrid=False
        ),
        yaxis=dict(
            range=[-0.1, 2.5],
            title="y (m)",
            scaleanchor="x",
            scaleratio=1,
            showgrid=False,
        ),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=1.15,
                x=0.5,
                xanchor="center",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            dict(
                                # Fixed 50 ms per frame. Each frame spans
                                # frame_step * dt seconds of trajectory, so
                                # playback speed is not tied to real time.
                                frame=dict(duration=50, redraw=True),
                                fromcurrent=True,
                                transition=dict(duration=0),
                            ),
                        ],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        # Plotly's pause idiom: animate to [None] immediately.
                        args=[
                            [None],
                            dict(
                                frame=dict(duration=0, redraw=False),
                                mode="immediate",
                            ),
                        ],
                    ),
                ],
            )
        ],
        sliders=[
            dict(
                active=0,
                yanchor="top",
                xanchor="left",
                currentvalue=dict(
                    prefix="Time: ",
                    suffix="s",
                    visible=True,
                    xanchor="center",
                    offset=20,
                ),
                pad=dict(b=10, t=60),
                steps=[
                    dict(
                        args=[
                            [str(i)],
                            dict(
                                frame=dict(duration=0, redraw=True),
                                mode="immediate",
                                transition=dict(duration=0),
                            ),
                        ],
                        label=f"{float(times[i]):.2f}",
                        method="animate",
                    )
                    for i in frame_indices
                ],
                x=0.1,
                y=0,
                len=0.8,
            )
        ],
        height=500,
        showlegend=False,
        margin=dict(t=80, b=100),
        plot_bgcolor="white",
    ),
)
HTML(fig_anim.to_html(full_html=False, include_plotlyjs="cdn", auto_play=False))

Hide code cell source

# State and control trajectories.
fig_traj = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=(
        "Cart Position",
        "Pole Angle",
        # The legend is hidden, so state which velocity curve is which.
        "Velocities (ẋ solid, θ̇ dashed)",
        "Control Force",
    ),
    vertical_spacing=0.15,
    horizontal_spacing=0.1,
)

# Convert to plain Python lists once; several traces share the same time axis.
state_t = times.tolist()
control_t = control_times.tolist()

# Cart position.
fig_traj.add_trace(
    go.Scatter(
        x=state_t,
        y=states[:, 0].tolist(),
        mode="lines",
        line=dict(color="steelblue", width=2),
        name="x",
    ),
    row=1,
    col=1,
)

# Pole angle, with the upright target (theta = pi) marked.
fig_traj.add_trace(
    go.Scatter(
        x=state_t,
        y=states[:, 1].tolist(),
        mode="lines",
        line=dict(color="coral", width=2),
        name="θ",
    ),
    row=1,
    col=2,
)
fig_traj.add_hline(
    y=float(jnp.pi),
    line_dash="dash",
    line_color="gray",
    row=1,
    col=2,
    annotation_text="π",
)

# Velocities: colours match the position/angle panels above.
fig_traj.add_trace(
    go.Scatter(
        x=state_t,
        y=states[:, 2].tolist(),
        mode="lines",
        line=dict(color="steelblue", width=2),
        name="ẋ",
    ),
    row=2,
    col=1,
)
fig_traj.add_trace(
    go.Scatter(
        x=state_t,
        y=states[:, 3].tolist(),
        mode="lines",
        line=dict(color="coral", width=2, dash="dash"),
        name="θ̇",
    ),
    row=2,
    col=1,
)

# Control force, with the +/- force_limit bounds marked.
fig_traj.add_trace(
    go.Scatter(
        x=control_t,
        y=controls[:, 0].tolist(),
        mode="lines",
        line=dict(color="forestgreen", width=2),
        name="F",
        fill="tozeroy",
        fillcolor="rgba(34, 139, 34, 0.2)",
    ),
    row=2,
    col=2,
)
for bound in (force_limit, -force_limit):
    fig_traj.add_hline(
        y=bound, line_dash="dash", line_color="red", opacity=0.5, row=2, col=2
    )

# Axis labels: every panel shares the time x-axis.
y_axis_titles = {
    (1, 1): "Position (m)",
    (1, 2): "Angle (rad)",
    (2, 1): "Velocity (m/s, rad/s)",
    (2, 2): "Force (N)",
}
for (row, col), y_title in y_axis_titles.items():
    fig_traj.update_xaxes(title_text="Time (s)", row=row, col=col)
    fig_traj.update_yaxes(title_text=y_title, row=row, col=col)

fig_traj.update_layout(height=500, showlegend=False, margin=dict(t=40, b=40))
HTML(fig_traj.to_html(full_html=False, include_plotlyjs="cdn"))