Cart-pole (shooting)
In this notebook, we solve a cart-pole swing-up problem using single shooting: finding control inputs that swing a pendulum from hanging down to balancing upright.
Note
This notebook demonstrates the shooting method, where we optimize only the control trajectory and compute states by forward simulation. For contrast, see Cart-pole (collocation), which uses direct collocation to optimize both states and controls with dynamics as constraints.
Key difference from direct collocation:
Shooting: Optimizes only controls; states are computed by simulating dynamics forward
Direct collocation: Optimizes both states and controls, with dynamics enforced as equality constraints
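Shooting is simpler, but it can struggle with long time horizons. Every state is obtained by integrating forward from the initial state, so the final state can become extremely sensitive to the early controls.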
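To make the sensitivity argument concrete, here is a minimal sketch on a hypothetical unstable scalar system \(x_{k+1} = a x_k + b u_k\) with \(a > 1\) (not the cart-pole). Differentiating the final state of a `jax.lax.scan` rollout with `jax.grad` shows that the first control's influence on the final state is \(a^{N-1} b\), which grows exponentially with the horizon \(N\). The names `final_state`, `a`, and `b` are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical unstable scalar system: x_{k+1} = a * x_k + b * u_k, with a > 1.
a, b = 1.2, 1.0


def final_state(controls: jax.Array) -> jax.Array:
    """Roll the system forward from x_0 = 0 and return x_N."""

    def step(x: jax.Array, u: jax.Array) -> tuple[jax.Array, jax.Array]:
        x_next = a * x + b * u
        return x_next, x_next

    x0 = jnp.zeros((), dtype=controls.dtype)
    x_final, _ = jax.lax.scan(step, x0, controls)
    return x_final


for horizon in (10, 50):
    # d x_N / d u_k = a**(N - 1 - k) * b, so the earliest control dominates.
    sensitivity = jax.grad(final_state)(jnp.zeros(horizon))
    print(
        f"N={horizon}: dx_N/du_0 = {float(sensitivity[0]):.1f}, "
        f"dx_N/du_(N-1) = {float(sensitivity[-1]):.1f}"
    )
```

Going from 10 to 50 steps multiplies the first control's sensitivity by \(1.2^{40}\), which is the kind of conditioning problem that makes long-horizon shooting hard.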
Features used:
Var for the control trajectory (a single variable containing all controls)
@jaxls.Cost.factory for the terminal-state, hold-phase, and control-regularization costs
Forward simulation with jax.lax.scan
import jax
import jax.numpy as jnp
import jaxls
Cart-pole dynamics
The cart-pole system has state \([x, \theta, \dot{x}, \dot{\theta}]\) where \(x\) is cart position and \(\theta\) is pole angle (0 = hanging down, \(\pi\) = upright). The control input is a horizontal force on the cart.
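For reference, reading off `cart_pole_dynamics` below, the equations of motion it implements are (with cart mass \(m_c\), pole mass \(m_p\), pole half-length \(\ell\), gravity \(g\), force \(f\), and \(\tau\) corresponding to the intermediate `temp` in the code):

\[
\begin{aligned}
M &= m_c + m_p, \qquad \tau = \frac{f + m_p \ell \dot\theta^2 \sin\theta}{M}, \\
\ddot\theta &= \frac{-g \sin\theta - \cos\theta\,\tau}{\ell\left(\frac{4}{3} - \frac{m_p \cos^2\theta}{M}\right)}, \\
\ddot x &= \tau - \frac{m_p \ell\,\ddot\theta \cos\theta}{M}.
\end{aligned}
\]

At rest (\(\theta = \dot\theta = 0\)) with \(f = 0\), every acceleration term vanishes, so the hanging-down initial state is an equilibrium. This is why the zero-control simulation below stays at the origin.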
# Physical parameters.
m_cart = 1.0 # Cart mass (kg)
m_pole = 0.1 # Pole mass (kg)
length = 0.5 # Pole half-length (m)
gravity = 9.81 # Gravitational acceleration (m/s^2)
# Trajectory parameters.
n_swing = 50 # Timesteps for swing-up phase
n_hold = 10 # Timesteps to hold upright position
n_steps = n_swing + n_hold # Total timesteps
dt = 0.05 # Time step (s)
total_time = n_steps * dt
print(f"Total trajectory time: {total_time:.2f}s ({n_swing} swing + {n_hold} hold)")
Total trajectory time: 3.00s (50 swing + 10 hold)
@jax.jit
def cart_pole_dynamics(state: jax.Array, force: jax.Array) -> jax.Array:
    """Compute state derivatives for the cart-pole system.

    State: [x, theta, x_dot, theta_dot]
    theta = 0: pole hanging down, theta = pi: pole upright

    Args:
        state: Current state [x, theta, x_dot, theta_dot]
        force: Control force [f]

    Returns:
        State derivatives [x_dot, theta_dot, x_ddot, theta_ddot]
    """
    x, theta, x_dot, theta_dot = state
    f = force[0]
    sin_th = jnp.sin(theta)
    cos_th = jnp.cos(theta)

    # Total mass.
    total_mass = m_cart + m_pole
    pole_mass_length = m_pole * length

    # Equations of motion for theta=0 being DOWN (stable equilibrium).
    temp = (f + pole_mass_length * theta_dot**2 * sin_th) / total_mass
    theta_ddot = (-gravity * sin_th - cos_th * temp) / (
        length * (4.0 / 3.0 - m_pole * cos_th**2 / total_mass)
    )
    x_ddot = temp - pole_mass_length * theta_ddot * cos_th / total_mass
    return jnp.array([x_dot, theta_dot, x_ddot, theta_ddot])
@jax.jit
def rk4_step(state: jax.Array, force: jax.Array, dt: float) -> jax.Array:
    """Runge-Kutta 4th order integration step.

    Args:
        state: Current state [x, theta, x_dot, theta_dot]
        force: Control force [f]
        dt: Time step

    Returns:
        Next state after integration
    """
    # The force is held constant across all four stages (zero-order hold).
    k1 = cart_pole_dynamics(state, force)
    k2 = cart_pole_dynamics(state + 0.5 * dt * k1, force)
    k3 = cart_pole_dynamics(state + 0.5 * dt * k2, force)
    k4 = cart_pole_dynamics(state + dt * k3, force)
    return state + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
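A quick way to sanity-check a fourth-order integrator is to halve the step size and confirm that the global error drops by roughly \(2^4 = 16\). The sketch below applies the same RK4 update to the hypothetical scalar test problem \(y' = -y\), \(y(0) = 1\), whose exact solution is \(e^{-t}\). It uses plain Python floats (double precision) so rounding does not mask the truncation error; `rk4_step_scalar` and `global_error` are illustrative helpers, not part of the notebook.

```python
import math
from typing import Callable


def rk4_step_scalar(f: Callable[[float], float], y: float, dt: float) -> float:
    """One classical RK4 step for an autonomous scalar ODE y' = f(y)."""
    k1 = f(y)
    k2 = f(y + 0.5 * dt * k1)
    k3 = f(y + 0.5 * dt * k2)
    k4 = f(y + dt * k3)
    return y + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)


def global_error(dt: float, t_final: float = 1.0) -> float:
    """Integrate y' = -y from y(0) = 1 to t_final and compare with exp(-t_final)."""
    n = round(t_final / dt)
    y = 1.0
    for _ in range(n):
        y = rk4_step_scalar(lambda v: -v, y, dt)
    return abs(y - math.exp(-t_final))


ratio = global_error(0.1) / global_error(0.05)
print(f"error ratio when halving dt: {ratio:.1f}")  # close to 16 for a 4th-order method
```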
Forward simulation
The key difference from direct collocation: we simulate the entire trajectory forward using jax.lax.scan. Given a control sequence, we compute all states by integrating the dynamics.
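The pattern used by `simulate_trajectory` below (a `jax.lax.scan` whose step function returns the next state as both the new carry and the per-step output, followed by prepending the initial state) is equivalent to an ordinary Python loop. A minimal sketch, using a hypothetical discrete-time double integrator in place of the cart-pole so that it runs on its own:

```python
import jax
import jax.numpy as jnp

dt = 0.1


def step_fn(state: jax.Array, u: jax.Array) -> jax.Array:
    """Hypothetical double integrator: state = [position, velocity], u = [acceleration]."""
    pos, vel = state
    return jnp.array([pos + dt * vel, vel + dt * u[0]])


def rollout_scan(x0: jax.Array, controls: jax.Array) -> jax.Array:
    def body(x: jax.Array, u: jax.Array) -> tuple[jax.Array, jax.Array]:
        x_next = step_fn(x, u)
        return x_next, x_next  # (new carry, stacked per-step output)

    _, xs = jax.lax.scan(body, x0, controls)
    return jnp.concatenate([x0[None, :], xs], axis=0)  # prepend x0 -> (N + 1, 2)


def rollout_loop(x0: jax.Array, controls: jax.Array) -> jax.Array:
    states = [x0]
    for u in controls:  # iterates over rows of shape (1,)
        states.append(step_fn(states[-1], u))
    return jnp.stack(states)


x0 = jnp.zeros(2)
controls = jnp.ones((5, 1))
traj = rollout_scan(x0, controls)
print(traj.shape)  # (6, 2)
print(bool(jnp.allclose(traj, rollout_loop(x0, controls))))  # True
```

The scan version traces the step function once and compiles a single loop, which keeps compile time independent of the horizon; the Python loop is only there for comparison.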
# Initial and target states.
initial_state = jnp.array([0.0, 0.0, 0.0, 0.0]) # x=0, theta=0 (down), zero velocity
target_state = jnp.array([0.0, jnp.pi, 0.0, 0.0]) # x=0, theta=pi (up), zero velocity
@jax.jit
def simulate_trajectory(controls: jax.Array) -> jax.Array:
    """Simulate the cart-pole system forward given a control sequence.

    Args:
        controls: Control forces, shape (n_steps, 1)

    Returns:
        State trajectory, shape (n_steps + 1, 4)
    """

    def step(state: jax.Array, control: jax.Array) -> tuple[jax.Array, jax.Array]:
        next_state = rk4_step(state, control, dt)
        return next_state, next_state

    _, trajectory = jax.lax.scan(step, initial_state, controls)
    # Prepend initial state.
    return jnp.concatenate([initial_state[None, :], trajectory], axis=0)
# Test forward simulation with zero controls.
test_controls = jnp.zeros((n_steps, 1))
test_trajectory = simulate_trajectory(test_controls)
print(f"Trajectory shape: {test_trajectory.shape}")
print(f"Final state with zero controls: {test_trajectory[-1]}")
Trajectory shape: (61, 4)
Final state with zero controls: [0. 0. 0. 0.]
Control variable and costs
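In the shooting method, we optimize a single variable containing the entire control trajectory. The costs are:
Terminal cost: Penalize deviation from the upright target state at the final time
Hold cost: Penalize deviation from the upright target for every state from timestep n_swing through the final state
Control regularization: Penalize large control inputs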
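Each cost returns a residual vector, and a least-squares solver minimizes the sum of squared residual entries. A weight \(w\) on a residual entry therefore scales that entry's quadratic penalty by \(w^2\): with the terminal angle weight of 20, a 0.1 rad angle error contributes \((20 \times 0.1)^2 = 4\), and the 0.1 scale in `control_regularization` amounts to a penalty of \(0.01 f^2\) per timestep. A small sketch of that arithmetic; the weights are copied from the cost definitions below, while the error and control values are hypothetical.

```python
import jax.numpy as jnp

# Terminal-cost weights from the notebook, ordered [x, theta, x_dot, theta_dot].
weights = jnp.array([10.0, 20.0, 5.0, 5.0])

# Hypothetical final-state error: only the angle is off, by 0.1 rad.
state_error = jnp.array([0.0, 0.1, 0.0, 0.0])
residual = state_error * weights
terminal_penalty = jnp.sum(residual**2)  # (20 * 0.1)**2 = 4

# Hypothetical control sequence: a constant 5 N force for 3 steps.
controls = jnp.full((3, 1), 5.0)
control_residual = controls.flatten() * 0.1
control_penalty = jnp.sum(control_residual**2)  # 3 * (0.1 * 5)**2 = 0.75

print(float(terminal_penalty), float(control_penalty))
```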
class ControlTrajectoryVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((n_steps, 1))
):
    """Control trajectory variable: (n_steps, 1) array of forces."""


control_var = ControlTrajectoryVar(id=0)


@jaxls.Cost.factory
def terminal_cost(
    vals: jaxls.VarValues,
    var: ControlTrajectoryVar,
    target: jax.Array,
) -> jax.Array:
    """Penalize deviation from target state at final time.

    Forward simulate the trajectory, then compute error at final state.
    """
    controls = vals[var]
    trajectory = simulate_trajectory(controls)
    final_state = trajectory[-1]
    # Weight position and angle errors more heavily.
    weights = jnp.array([10.0, 20.0, 5.0, 5.0])
    return (final_state - target) * weights


@jaxls.Cost.factory
def hold_cost(
    vals: jaxls.VarValues,
    var: ControlTrajectoryVar,
    target: jax.Array,
) -> jax.Array:
    """Penalize deviation from upright during the hold phase."""
    controls = vals[var]
    trajectory = simulate_trajectory(controls)
    # Get states during hold phase (use static slice with captured n_swing).
    hold_states = jax.lax.slice_in_dim(trajectory, n_swing, n_steps + 1, axis=0)
    # Penalize deviation from target for each hold state.
    weights = jnp.array([5.0, 10.0, 2.0, 2.0])
    errors = (hold_states - target) * weights
    return errors.flatten()


@jaxls.Cost.factory
def control_regularization(
    vals: jaxls.VarValues,
    var: ControlTrajectoryVar,
) -> jax.Array:
    """Penalize control effort (L2 regularization)."""
    controls = vals[var]
    return controls.flatten() * 0.1
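The slice in `hold_cost` keeps trajectory rows `n_swing` through `n_steps` inclusive, that is `n_hold + 1 = 11` states. Its residual therefore has 44 entries, and the final state is penalized by both `hold_cost` and `terminal_cost`. Because `n_swing` is a plain Python integer, `jax.lax.slice_in_dim` here behaves like ordinary basic slicing. A quick check on a placeholder array with the same shape as the simulated trajectory (the array contents are arbitrary):

```python
import jax
import jax.numpy as jnp

n_swing, n_hold = 50, 10
n_steps = n_swing + n_hold

# Placeholder with the same shape as simulate_trajectory's output: (n_steps + 1, 4).
trajectory = jnp.arange((n_steps + 1) * 4, dtype=jnp.float32).reshape(n_steps + 1, 4)

hold_states = jax.lax.slice_in_dim(trajectory, n_swing, n_steps + 1, axis=0)
print(hold_states.shape)  # (11, 4)
print(bool(jnp.array_equal(hold_states, trajectory[n_swing:])))  # True
print(hold_states.flatten().shape)  # (44,) residual entries, as in hold_cost
```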
Solving
costs: list[jaxls.Cost] = [
terminal_cost(control_var, target_state),
hold_cost(control_var, target_state),
control_regularization(control_var),
]
print(f"Created {len(costs)} cost functions")
print(f"Control variable shape: ({n_steps}, 1)")
Created 3 cost functions
Control variable shape: (60, 1)
# Initial guess: small random controls to break symmetry.
key = jax.random.PRNGKey(42)
initial_controls = jax.random.normal(key, (n_steps, 1)) * 0.1
initial_vals = jaxls.VarValues.make([control_var.with_value(initial_controls)])
# Create the problem.
problem = jaxls.LeastSquaresProblem(costs, [control_var])
# Visualize the problem structure.
problem.show()
# Analyze the problem.
problem = problem.analyze()
INFO | Building optimization problem with 3 terms and 1 variables: 3 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: terminal_cost
INFO | Vectorizing group with 1 costs, 1 variables each: hold_cost
INFO | Vectorizing group with 1 costs, 1 variables each: control_regularization
solution = problem.solve(initial_vals)
INFO | step #0: cost=14820.0186 lambd=0.0005 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3938.88013 (avg 984.72003)
INFO | - hold_cost(1): 10881.13477 (avg 247.29852)
INFO | - control_regularization(1): 0.00345 (avg 0.00006)
INFO | accepted=True ATb_norm=5.20e+02 cost_prev=14820.0186 cost_new=10036.3008
INFO | step #1: cost=10036.2979 lambd=0.0003 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #2: cost=10036.2979 lambd=0.0005 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #3: cost=10036.2979 lambd=0.0010 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #4: cost=10036.2979 lambd=0.0020 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #5: cost=10036.2979 lambd=0.0040 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #6: cost=10036.2979 lambd=0.0080 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #7: cost=10036.2979 lambd=0.0160 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #8: cost=10036.2979 lambd=0.0320 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #9: cost=10036.2979 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | step #10: cost=10036.2979 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 623.13776 (avg 155.78444)
INFO | - hold_cost(1): 9330.73438 (avg 212.06215)
INFO | - control_regularization(1): 82.42606 (avg 1.37377)
INFO | accepted=True ATb_norm=2.84e+03 cost_prev=10036.3008 cost_new=4684.3662
INFO | step #11: cost=4684.3662 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 513.90247 (avg 128.47562)
INFO | - hold_cost(1): 4061.28027 (avg 92.30183)
INFO | - control_regularization(1): 109.18370 (avg 1.81973)
INFO | accepted=True ATb_norm=1.40e+03 cost_prev=4684.3662 cost_new=4028.7876
INFO | step #12: cost=4028.7878 lambd=0.0320 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 970.49084 (avg 242.62271)
INFO | - hold_cost(1): 2844.49316 (avg 64.64758)
INFO | - control_regularization(1): 213.80394 (avg 3.56340)
INFO | step #13: cost=4028.7878 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 970.49084 (avg 242.62271)
INFO | - hold_cost(1): 2844.49316 (avg 64.64758)
INFO | - control_regularization(1): 213.80394 (avg 3.56340)
INFO | step #14: cost=4028.7878 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 970.49084 (avg 242.62271)
INFO | - hold_cost(1): 2844.49316 (avg 64.64758)
INFO | - control_regularization(1): 213.80394 (avg 3.56340)
INFO | step #15: cost=4028.7878 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 970.49084 (avg 242.62271)
INFO | - hold_cost(1): 2844.49316 (avg 64.64758)
INFO | - control_regularization(1): 213.80394 (avg 3.56340)
INFO | accepted=True ATb_norm=1.27e+03 cost_prev=4028.7876 cost_new=3628.3013
INFO | step #16: cost=3628.3010 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1181.71997 (avg 295.42999)
INFO | - hold_cost(1): 2341.52197 (avg 53.21641)
INFO | - control_regularization(1): 105.05909 (avg 1.75098)
INFO | accepted=True ATb_norm=1.14e+03 cost_prev=3628.3013 cost_new=1633.5022
INFO | step #17: cost=1633.5022 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | step #18: cost=1633.5022 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | step #19: cost=1633.5022 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | step #20: cost=1633.5022 lambd=0.5120 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | step #21: cost=1633.5022 lambd=1.0240 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | step #22: cost=1633.5022 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | step #23: cost=1633.5022 lambd=4.0960 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 345.59445 (avg 86.39861)
INFO | - hold_cost(1): 1256.49683 (avg 28.55675)
INFO | - control_regularization(1): 31.41090 (avg 0.52351)
INFO | accepted=True ATb_norm=3.72e+02 cost_prev=1633.5022 cost_new=1007.1594
INFO | step #24: cost=1007.1594 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 319.02032 (avg 79.75508)
INFO | - hold_cost(1): 653.23999 (avg 14.84636)
INFO | - control_regularization(1): 34.89909 (avg 0.58165)
INFO | step #25: cost=1007.1594 lambd=4.0960 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 319.02032 (avg 79.75508)
INFO | - hold_cost(1): 653.23999 (avg 14.84636)
INFO | - control_regularization(1): 34.89909 (avg 0.58165)
INFO | accepted=True ATb_norm=4.61e+02 cost_prev=1007.1594 cost_new=572.4126
INFO | step #26: cost=572.4127 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 145.04785 (avg 36.26196)
INFO | - hold_cost(1): 390.94324 (avg 8.88507)
INFO | - control_regularization(1): 36.42155 (avg 0.60703)
INFO | step #27: cost=572.4127 lambd=4.0960 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 145.04785 (avg 36.26196)
INFO | - hold_cost(1): 390.94324 (avg 8.88507)
INFO | - control_regularization(1): 36.42155 (avg 0.60703)
INFO | accepted=True ATb_norm=3.69e+02 cost_prev=572.4126 cost_new=340.4744
INFO | step #28: cost=340.4744 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 42.54160 (avg 10.63540)
INFO | - hold_cost(1): 260.97519 (avg 5.93125)
INFO | - control_regularization(1): 36.95764 (avg 0.61596)
INFO | step #29: cost=340.4744 lambd=4.0960 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 42.54160 (avg 10.63540)
INFO | - hold_cost(1): 260.97519 (avg 5.93125)
INFO | - control_regularization(1): 36.95764 (avg 0.61596)
INFO | accepted=True ATb_norm=2.10e+02 cost_prev=340.4744 cost_new=258.7848
INFO | step #30: cost=258.7848 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 17.60720 (avg 4.40180)
INFO | - hold_cost(1): 203.89160 (avg 4.63390)
INFO | - control_regularization(1): 37.28600 (avg 0.62143)
INFO | step #31: cost=258.7848 lambd=4.0960 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 17.60720 (avg 4.40180)
INFO | - hold_cost(1): 203.89160 (avg 4.63390)
INFO | - control_regularization(1): 37.28600 (avg 0.62143)
INFO | accepted=True ATb_norm=1.31e+02 cost_prev=258.7848 cost_new=217.3717
INFO | step #32: cost=217.3717 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 9.49219 (avg 2.37305)
INFO | - hold_cost(1): 170.40802 (avg 3.87291)
INFO | - control_regularization(1): 37.47153 (avg 0.62453)
INFO | accepted=True ATb_norm=8.43e+01 cost_prev=217.3717 cost_new=205.9092
INFO | step #33: cost=205.9092 lambd=1.0240 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 24.00981 (avg 6.00245)
INFO | - hold_cost(1): 144.31160 (avg 3.27981)
INFO | - control_regularization(1): 37.58778 (avg 0.62646)
INFO | step #34: cost=205.9092 lambd=2.0480 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 24.00981 (avg 6.00245)
INFO | - hold_cost(1): 144.31160 (avg 3.27981)
INFO | - control_regularization(1): 37.58778 (avg 0.62646)
INFO | accepted=True ATb_norm=2.04e+02 cost_prev=205.9092 cost_new=155.5018
INFO | step #35: cost=155.5017 lambd=1.0240 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 7.42080 (avg 1.85520)
INFO | - hold_cost(1): 110.72858 (avg 2.51656)
INFO | - control_regularization(1): 37.35236 (avg 0.62254)
INFO | accepted=True ATb_norm=9.93e+01 cost_prev=155.5018 cost_new=150.3035
INFO | step #36: cost=150.3035 lambd=0.5120 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 21.09509 (avg 5.27377)
INFO | - hold_cost(1): 92.15547 (avg 2.09444)
INFO | - control_regularization(1): 37.05294 (avg 0.61755)
INFO | step #37: cost=150.3035 lambd=1.0240 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 21.09509 (avg 5.27377)
INFO | - hold_cost(1): 92.15547 (avg 2.09444)
INFO | - control_regularization(1): 37.05294 (avg 0.61755)
INFO | accepted=True ATb_norm=2.25e+02 cost_prev=150.3035 cost_new=107.0154
INFO | step #38: cost=107.0154 lambd=0.5120 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 5.88594 (avg 1.47149)
INFO | - hold_cost(1): 64.37170 (avg 1.46299)
INFO | - control_regularization(1): 36.75779 (avg 0.61263)
INFO | step #39: cost=107.0154 lambd=1.0240 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 5.88594 (avg 1.47149)
INFO | - hold_cost(1): 64.37170 (avg 1.46299)
INFO | - control_regularization(1): 36.75779 (avg 0.61263)
INFO | accepted=True ATb_norm=1.15e+02 cost_prev=107.0154 cost_new=88.8110
INFO | step #40: cost=88.8110 lambd=0.5120 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 2.51078 (avg 0.62770)
INFO | - hold_cost(1): 49.79275 (avg 1.13165)
INFO | - control_regularization(1): 36.50744 (avg 0.60846)
INFO | accepted=True ATb_norm=6.90e+01 cost_prev=88.8110 cost_new=83.8815
INFO | step #41: cost=83.8814 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 9.01481 (avg 2.25370)
INFO | - hold_cost(1): 38.66769 (avg 0.87881)
INFO | - control_regularization(1): 36.19894 (avg 0.60332)
INFO | step #42: cost=83.8814 lambd=0.5120 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 9.01481 (avg 2.25370)
INFO | - hold_cost(1): 38.66769 (avg 0.87881)
INFO | - control_regularization(1): 36.19894 (avg 0.60332)
INFO | accepted=True ATb_norm=1.68e+02 cost_prev=83.8815 cost_new=64.7548
INFO | step #43: cost=64.7548 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.24689 (avg 0.81172)
INFO | - hold_cost(1): 25.47734 (avg 0.57903)
INFO | - control_regularization(1): 36.03053 (avg 0.60051)
INFO | step #44: cost=64.7548 lambd=0.5120 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.24689 (avg 0.81172)
INFO | - hold_cost(1): 25.47734 (avg 0.57903)
INFO | - control_regularization(1): 36.03053 (avg 0.60051)
INFO | accepted=True ATb_norm=1.01e+02 cost_prev=64.7548 cost_new=55.5099
INFO | step #45: cost=55.5099 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1.41236 (avg 0.35309)
INFO | - hold_cost(1): 18.26465 (avg 0.41511)
INFO | - control_regularization(1): 35.83292 (avg 0.59722)
INFO | accepted=True ATb_norm=6.48e+01 cost_prev=55.5099 cost_new=55.2921
INFO | step #46: cost=55.2921 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 5.59101 (avg 1.39775)
INFO | - hold_cost(1): 14.36527 (avg 0.32648)
INFO | - control_regularization(1): 35.33578 (avg 0.58893)
INFO | step #47: cost=55.2921 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 5.59101 (avg 1.39775)
INFO | - hold_cost(1): 14.36527 (avg 0.32648)
INFO | - control_regularization(1): 35.33578 (avg 0.58893)
INFO | accepted=True ATb_norm=1.48e+02 cost_prev=55.2921 cost_new=45.7453
INFO | step #48: cost=45.7453 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 2.02972 (avg 0.50743)
INFO | - hold_cost(1): 8.95305 (avg 0.20348)
INFO | - control_regularization(1): 34.76252 (avg 0.57938)
INFO | step #49: cost=45.7453 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 2.02972 (avg 0.50743)
INFO | - hold_cost(1): 8.95305 (avg 0.20348)
INFO | - control_regularization(1): 34.76252 (avg 0.57938)
INFO | accepted=True ATb_norm=8.89e+01 cost_prev=45.7453 cost_new=41.8801
INFO | step #50: cost=41.8801 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1.12498 (avg 0.28124)
INFO | - hold_cost(1): 6.62376 (avg 0.15054)
INFO | - control_regularization(1): 34.13136 (avg 0.56886)
INFO | step #51: cost=41.8801 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1.12498 (avg 0.28124)
INFO | - hold_cost(1): 6.62376 (avg 0.15054)
INFO | - control_regularization(1): 34.13136 (avg 0.56886)
INFO | accepted=True ATb_norm=6.54e+01 cost_prev=41.8801 cost_new=39.5404
INFO | step #52: cost=39.5404 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.73957 (avg 0.18489)
INFO | - hold_cost(1): 5.34821 (avg 0.12155)
INFO | - control_regularization(1): 33.45261 (avg 0.55754)
INFO | step #53: cost=39.5404 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.73957 (avg 0.18489)
INFO | - hold_cost(1): 5.34821 (avg 0.12155)
INFO | - control_regularization(1): 33.45261 (avg 0.55754)
INFO | accepted=True ATb_norm=5.19e+01 cost_prev=39.5404 cost_new=37.9006
INFO | step #54: cost=37.9006 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.55865 (avg 0.13966)
INFO | - hold_cost(1): 4.59287 (avg 0.10438)
INFO | - control_regularization(1): 32.74909 (avg 0.54582)
INFO | accepted=True ATb_norm=4.43e+01 cost_prev=37.9006 cost_new=40.9190
INFO | step #55: cost=40.9190 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.78697 (avg 0.94674)
INFO | - hold_cost(1): 5.69330 (avg 0.12939)
INFO | - control_regularization(1): 31.43876 (avg 0.52398)
INFO | accepted=True ATb_norm=1.41e+02 cost_prev=40.9190 cost_new=82.7037
INFO | step #56: cost=82.7037 lambd=0.0320 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 31.58258 (avg 7.89565)
INFO | - hold_cost(1): 21.73989 (avg 0.49409)
INFO | - control_regularization(1): 29.38120 (avg 0.48969)
INFO | step #57: cost=82.7037 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 31.58258 (avg 7.89565)
INFO | - hold_cost(1): 21.73989 (avg 0.49409)
INFO | - control_regularization(1): 29.38120 (avg 0.48969)
INFO | accepted=True ATb_norm=4.51e+02 cost_prev=82.7037 cost_new=58.5050
INFO | step #58: cost=58.5050 lambd=0.0320 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 17.82385 (avg 4.45596)
INFO | - hold_cost(1): 13.26104 (avg 0.30139)
INFO | - control_regularization(1): 27.42006 (avg 0.45700)
INFO | step #59: cost=58.5050 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 17.82385 (avg 4.45596)
INFO | - hold_cost(1): 13.26104 (avg 0.30139)
INFO | - control_regularization(1): 27.42006 (avg 0.45700)
INFO | step #60: cost=58.5050 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 17.82385 (avg 4.45596)
INFO | - hold_cost(1): 13.26104 (avg 0.30139)
INFO | - control_regularization(1): 27.42006 (avg 0.45700)
INFO | accepted=True ATb_norm=3.59e+02 cost_prev=58.5050 cost_new=31.9685
INFO | step #61: cost=31.9685 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1.86539 (avg 0.46635)
INFO | - hold_cost(1): 3.80649 (avg 0.08651)
INFO | - control_regularization(1): 26.29662 (avg 0.43828)
INFO | step #62: cost=31.9685 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1.86539 (avg 0.46635)
INFO | - hold_cost(1): 3.80649 (avg 0.08651)
INFO | - control_regularization(1): 26.29662 (avg 0.43828)
INFO | step #63: cost=31.9685 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 1.86539 (avg 0.46635)
INFO | - hold_cost(1): 3.80649 (avg 0.08651)
INFO | - control_regularization(1): 26.29662 (avg 0.43828)
INFO | accepted=True ATb_norm=1.12e+02 cost_prev=31.9685 cost_new=29.1282
INFO | step #64: cost=29.1282 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.36223 (avg 0.09056)
INFO | - hold_cost(1): 2.98873 (avg 0.06793)
INFO | - control_regularization(1): 25.77723 (avg 0.42962)
INFO | accepted=True ATb_norm=4.02e+01 cost_prev=29.1282 cost_new=32.2756
INFO | step #65: cost=32.2756 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.00576 (avg 0.75144)
INFO | - hold_cost(1): 4.37940 (avg 0.09953)
INFO | - control_regularization(1): 24.89044 (avg 0.41484)
INFO | accepted=True ATb_norm=1.51e+02 cost_prev=32.2756 cost_new=70.9422
INFO | step #66: cost=70.9422 lambd=0.0320 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 28.18815 (avg 7.04704)
INFO | - hold_cost(1): 19.29500 (avg 0.43852)
INFO | - control_regularization(1): 23.45909 (avg 0.39098)
INFO | step #67: cost=70.9422 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 28.18815 (avg 7.04704)
INFO | - hold_cost(1): 19.29500 (avg 0.43852)
INFO | - control_regularization(1): 23.45909 (avg 0.39098)
INFO | accepted=True ATb_norm=5.10e+02 cost_prev=70.9422 cost_new=68.8171
INFO | step #68: cost=68.8171 lambd=0.0320 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 27.73076 (avg 6.93269)
INFO | - hold_cost(1): 19.00603 (avg 0.43196)
INFO | - control_regularization(1): 22.08033 (avg 0.36801)
INFO | step #69: cost=68.8171 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 27.73076 (avg 6.93269)
INFO | - hold_cost(1): 19.00603 (avg 0.43196)
INFO | - control_regularization(1): 22.08033 (avg 0.36801)
INFO | step #70: cost=68.8171 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 27.73076 (avg 6.93269)
INFO | - hold_cost(1): 19.00603 (avg 0.43196)
INFO | - control_regularization(1): 22.08033 (avg 0.36801)
INFO | accepted=True ATb_norm=5.35e+02 cost_prev=68.8171 cost_new=28.8338
INFO | step #71: cost=28.8338 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.20780 (avg 0.80195)
INFO | - hold_cost(1): 4.35866 (avg 0.09906)
INFO | - control_regularization(1): 21.26731 (avg 0.35446)
INFO | step #72: cost=28.8338 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.20780 (avg 0.80195)
INFO | - hold_cost(1): 4.35866 (avg 0.09906)
INFO | - control_regularization(1): 21.26731 (avg 0.35446)
INFO | accepted=True ATb_norm=1.80e+02 cost_prev=28.8338 cost_new=27.9708
INFO | step #73: cost=27.9708 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.12860 (avg 0.78215)
INFO | - hold_cost(1): 4.24473 (avg 0.09647)
INFO | - control_regularization(1): 20.59751 (avg 0.34329)
INFO | step #74: cost=27.9708 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.12860 (avg 0.78215)
INFO | - hold_cost(1): 4.24473 (avg 0.09647)
INFO | - control_regularization(1): 20.59751 (avg 0.34329)
INFO | step #75: cost=27.9708 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.12860 (avg 0.78215)
INFO | - hold_cost(1): 4.24473 (avg 0.09647)
INFO | - control_regularization(1): 20.59751 (avg 0.34329)
INFO | accepted=True ATb_norm=1.82e+02 cost_prev=27.9708 cost_new=23.3691
INFO | step #76: cost=23.3691 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.39292 (avg 0.09823)
INFO | - hold_cost(1): 2.72641 (avg 0.06196)
INFO | - control_regularization(1): 20.24974 (avg 0.33750)
INFO | step #77: cost=23.3691 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.39292 (avg 0.09823)
INFO | - hold_cost(1): 2.72641 (avg 0.06196)
INFO | - control_regularization(1): 20.24974 (avg 0.33750)
INFO | accepted=True ATb_norm=5.24e+01 cost_prev=23.3691 cost_new=23.0268
INFO | step #78: cost=23.0268 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.40640 (avg 0.10160)
INFO | - hold_cost(1): 2.69457 (avg 0.06124)
INFO | - control_regularization(1): 19.92582 (avg 0.33210)
INFO | step #79: cost=23.0268 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.40640 (avg 0.10160)
INFO | - hold_cost(1): 2.69457 (avg 0.06124)
INFO | - control_regularization(1): 19.92582 (avg 0.33210)
INFO | accepted=True ATb_norm=5.46e+01 cost_prev=23.0268 cost_new=22.6713
INFO | step #80: cost=22.6713 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.40498 (avg 0.10124)
INFO | - hold_cost(1): 2.66189 (avg 0.06050)
INFO | - control_regularization(1): 19.60442 (avg 0.32674)
INFO | accepted=True ATb_norm=5.53e+01 cost_prev=22.6713 cost_new=26.3223
INFO | step #81: cost=26.3223 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.17506 (avg 0.79376)
INFO | - hold_cost(1): 4.11187 (avg 0.09345)
INFO | - control_regularization(1): 19.03540 (avg 0.31726)
INFO | step #82: cost=26.3223 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.17506 (avg 0.79376)
INFO | - hold_cost(1): 4.11187 (avg 0.09345)
INFO | - control_regularization(1): 19.03540 (avg 0.31726)
INFO | step #83: cost=26.3223 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.17506 (avg 0.79376)
INFO | - hold_cost(1): 4.11187 (avg 0.09345)
INFO | - control_regularization(1): 19.03540 (avg 0.31726)
INFO | accepted=True ATb_norm=1.98e+02 cost_prev=26.3223 cost_new=21.6792
INFO | step #84: cost=21.6792 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.38231 (avg 0.09558)
INFO | - hold_cost(1): 2.55796 (avg 0.05814)
INFO | - control_regularization(1): 18.73889 (avg 0.31231)
INFO | step #85: cost=21.6792 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.38231 (avg 0.09558)
INFO | - hold_cost(1): 2.55796 (avg 0.05814)
INFO | - control_regularization(1): 18.73889 (avg 0.31231)
INFO | accepted=True ATb_norm=5.58e+01 cost_prev=21.6792 cost_new=21.3778
INFO | step #86: cost=21.3778 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.39291 (avg 0.09823)
INFO | - hold_cost(1): 2.53188 (avg 0.05754)
INFO | - control_regularization(1): 18.45297 (avg 0.30755)
INFO | accepted=True ATb_norm=5.78e+01 cost_prev=21.3778 cost_new=24.9109
INFO | step #87: cost=24.9109 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.04330 (avg 0.76083)
INFO | - hold_cost(1): 3.91643 (avg 0.08901)
INFO | - control_regularization(1): 17.95120 (avg 0.29919)
INFO | step #88: cost=24.9109 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.04330 (avg 0.76083)
INFO | - hold_cost(1): 3.91643 (avg 0.08901)
INFO | - control_regularization(1): 17.95120 (avg 0.29919)
INFO | accepted=True ATb_norm=2.05e+02 cost_prev=24.9109 cost_new=24.8942
INFO | step #89: cost=24.8942 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.36977 (avg 0.84244)
INFO | - hold_cost(1): 4.03785 (avg 0.09177)
INFO | - control_regularization(1): 17.48660 (avg 0.29144)
INFO | step #90: cost=24.8942 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.36977 (avg 0.84244)
INFO | - hold_cost(1): 4.03785 (avg 0.09177)
INFO | - control_regularization(1): 17.48660 (avg 0.29144)
INFO | step #91: cost=24.8942 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 3.36977 (avg 0.84244)
INFO | - hold_cost(1): 4.03785 (avg 0.09177)
INFO | - control_regularization(1): 17.48660 (avg 0.29144)
INFO | accepted=True ATb_norm=2.22e+02 cost_prev=24.8942 cost_new=19.9573
INFO | step #92: cost=19.9573 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.35839 (avg 0.08960)
INFO | - hold_cost(1): 2.34783 (avg 0.05336)
INFO | - control_regularization(1): 17.25113 (avg 0.28752)
INFO | step #93: cost=19.9573 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.35839 (avg 0.08960)
INFO | - hold_cost(1): 2.34783 (avg 0.05336)
INFO | - control_regularization(1): 17.25113 (avg 0.28752)
INFO | accepted=True ATb_norm=5.85e+01 cost_prev=19.9573 cost_new=19.7023
INFO | step #94: cost=19.7023 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.36276 (avg 0.09069)
INFO | - hold_cost(1): 2.31201 (avg 0.05255)
INFO | - control_regularization(1): 17.02754 (avg 0.28379)
INFO | accepted=True ATb_norm=6.01e+01 cost_prev=19.7023 cost_new=22.8621
INFO | step #95: cost=22.8621 lambd=0.0640 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 2.71344 (avg 0.67836)
INFO | - hold_cost(1): 3.53635 (avg 0.08037)
INFO | - control_regularization(1): 16.61228 (avg 0.27687)
INFO | step #96: cost=22.8621 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 2.71344 (avg 0.67836)
INFO | - hold_cost(1): 3.53635 (avg 0.08037)
INFO | - control_regularization(1): 16.61228 (avg 0.27687)
INFO | step #97: cost=22.8621 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 2.71344 (avg 0.67836)
INFO | - hold_cost(1): 3.53635 (avg 0.08037)
INFO | - control_regularization(1): 16.61228 (avg 0.27687)
INFO | accepted=True ATb_norm=2.09e+02 cost_prev=22.8621 cost_new=18.9556
INFO | step #98: cost=18.9556 lambd=0.1280 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.33499 (avg 0.08375)
INFO | - hold_cost(1): 2.21210 (avg 0.05028)
INFO | - control_regularization(1): 16.40852 (avg 0.27348)
INFO | step #99: cost=18.9556 lambd=0.2560 inexact_tol=1.0e-02
INFO | - terminal_cost(1): 0.33499 (avg 0.08375)
INFO | - hold_cost(1): 2.21210 (avg 0.05028)
INFO | - control_regularization(1): 16.40852 (avg 0.27348)
INFO | accepted=True ATb_norm=5.92e+01 cost_prev=18.9556 cost_new=18.7348
INFO | Terminated @ iteration #100: cost=18.7348 criteria=[0 0 0], term_deltas=1.2e-02,1.8e+01,2.3e-02
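Reading the log: each cost line reports the term's total and an `avg`. The step #0 numbers are consistent with `avg` being the total divided by the term's residual dimension: 4 entries for `terminal_cost`, \((n_{hold} + 1) \times 4 = 44\) for `hold_cost`, and 60 for `control_regularization`. This also helps explain why `hold_cost` dominates the total early on. A short check, with the totals and averages copied from the step #0 lines; the dictionary names are illustrative.

```python
n_swing, n_hold = 50, 10
n_steps = n_swing + n_hold

# Residual dimension of each cost term, derived from the cost definitions.
residual_dims = {
    "terminal_cost": 4,
    "hold_cost": (n_steps + 1 - n_swing) * 4,
    "control_regularization": n_steps * 1,
}

# (total, logged avg) pairs copied from the step #0 log lines.
step0 = {
    "terminal_cost": (3938.88013, 984.72003),
    "hold_cost": (10881.13477, 247.29852),
    "control_regularization": (0.00345, 0.00006),
}

for name, (total, logged_avg) in step0.items():
    avg = total / residual_dims[name]
    print(f"{name}: dim={residual_dims[name]}, total/dim={avg:.5f}, logged avg={logged_avg:.5f}")
```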
Visualization
# Extract solution.
optimal_controls = solution[control_var]
optimal_trajectory = simulate_trajectory(optimal_controls)
times = jnp.linspace(0, total_time, n_steps + 1)
control_times = jnp.linspace(0, total_time - dt, n_steps)
print(
f"Final state: x={float(optimal_trajectory[-1, 0]):.3f}, theta={float(optimal_trajectory[-1, 1]):.3f}"
)
print(f"Target theta (pi): {jnp.pi:.3f}")
print(f"Terminal angle error: {abs(float(optimal_trajectory[-1, 1]) - jnp.pi):.4f} rad")
Final state: x=-0.003, theta=3.153
Target theta (pi): 3.142
Terminal angle error: 0.0112 rad
Comparison with direct collocation
The shooting method optimizes only the control trajectory (60 decision variables, one force per timestep), while direct collocation in Cart-pole (collocation) optimizes both states and controls (304 decision variables: 61 states × 4 + 60 controls).
Shooting method advantages:
Fewer decision variables
Dynamics are satisfied by construction (no defect constraints)
Simpler problem formulation
Shooting method disadvantages:
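Sensitivity of the final state to early controls can grow exponentially with the time horizon
Harder to add state constraints: each simulated state depends on all earlier controls, so state constraints become dense functions of the whole control trajectory (this notebook handles the upright requirement through costs instead)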
Can struggle with long horizons or unstable dynamics
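The density point can be seen directly in the Jacobian of the simulated states with respect to the controls. In a shooting rollout, state \(x_{k+1}\) depends on every control \(u_0, \dots, u_k\), so the Jacobian is lower-triangular and fills in as the horizon grows, whereas collocation defect constraints each couple only neighboring knot points. A minimal sketch on a hypothetical scalar system \(x_{k+1} = a x_k + b u_k\) (`rollout`, `a`, and `b` are illustrative, not the cart-pole):

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar system: x_{k+1} = a * x_k + b * u_k.
a, b = 1.1, 1.0


def rollout(controls: jax.Array) -> jax.Array:
    """Return states x_1..x_N from x_0 = 0, shape (N,)."""

    def step(x: jax.Array, u: jax.Array) -> tuple[jax.Array, jax.Array]:
        x_next = a * x + b * u
        return x_next, x_next

    _, states = jax.lax.scan(step, jnp.zeros((), dtype=controls.dtype), controls)
    return states


n = 6
jac = jax.jacfwd(rollout)(jnp.zeros(n))  # jac[k, j] = d x_{k+1} / d u_j
print(jnp.round(jac, 3))
# Entries with j > k are exactly zero; every entry with j <= k equals a**(k - j) * b.
print(int(jnp.count_nonzero(jac)), "of", n * n, "entries are nonzero")  # 21 of 36
```

A constraint on a late state therefore touches almost every decision variable, which is one reason such constraints are often folded into costs in shooting formulations.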
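For this cart-pole problem, both methods work because the time horizon is short (3 s). Even so, the shooting solve above is far from smooth: many trust-region steps are rejected, some accepted steps increase the cost, and the solver stops at iteration 100 with a terminal angle error of about 0.011 rad. For longer horizons or more complex systems, direct collocation often provides better convergence.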
For solver configuration options, see jaxls.TrustRegionConfig and jaxls.TerminationConfig.