# Cart-pole (shooting)

In this notebook, we solve a cart-pole swing-up problem using single shooting: finding
control inputs that swing a pendulum from hanging down to balancing upright.

```{note}
This notebook demonstrates the shooting method, where we optimize only the control trajectory
and compute states by forward simulation. For contrast, see {doc}`cart_pole_collocation`, which uses
direct collocation to optimize both states and controls with dynamics as constraints.
```

Key difference from direct collocation:
- Shooting: Optimizes only controls; states are computed by simulating dynamics forward
- Direct collocation: Optimizes both states and controls, with dynamics enforced as equality constraints

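Shooting is simpler to set up, but it can struggle with long time horizons: the forward simulation amplifies small changes in early controls, so the initial value problem becomes increasingly sensitive as the horizon grows.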

Features used:
- {class}`~jaxls.Var` for control trajectory (single variable containing all controls)
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` for the terminal cost, hold cost, and control regularization
- Forward simulation with `jax.lax.scan`

In [None]:
import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");

In [2]:
import jax
import jax.numpy as jnp
import jaxls

## Cart-pole dynamics

The cart-pole system has state $[x, \theta, \dot{x}, \dot{\theta}]$ where $x$ is cart position and $\theta$ is pole angle (0 = hanging down, $\pi$ = upright). The control input is a horizontal force on the cart.
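
For reference, the equations of motion implemented by `cart_pole_dynamics` in the cell below can be written out as follows. This is a transcription of that code rather than an independent derivation; the symbols $m_c$, $m_p$, $l$, $g$, $F$, and the intermediate quantity $a$ are notation chosen here for the code's `m_cart`, `m_pole`, `length`, `gravity`, `f`, and `temp`.

$$
a = \frac{F + m_p l \dot{\theta}^2 \sin\theta}{m_c + m_p}, \qquad
\ddot{\theta} = \frac{-g \sin\theta - a \cos\theta}{l \left( \frac{4}{3} - \frac{m_p \cos^2\theta}{m_c + m_p} \right)}, \qquad
\ddot{x} = a - \frac{m_p l \ddot{\theta} \cos\theta}{m_c + m_p}
$$

With $F = 0$ and zero velocities, $a = 0$ and $\ddot{\theta}$ is proportional to $-\sin\theta$, so $\theta = 0$ and $\theta = \pi$ are both equilibria, and only $\theta = 0$ (hanging down) is stable.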

In [3]:
# Physical parameters.
m_cart = 1.0  # Cart mass (kg)
m_pole = 0.1  # Pole mass (kg)
length = 0.5  # Pole half-length (m)
gravity = 9.81  # Gravitational acceleration (m/s^2)

# Trajectory parameters.
n_swing = 50  # Timesteps for swing-up phase
n_hold = 10  # Timesteps to hold upright position
n_steps = n_swing + n_hold  # Total timesteps
dt = 0.05  # Time step (s)
total_time = n_steps * dt
print(f"Total trajectory time: {total_time:.2f}s ({n_swing} swing + {n_hold} hold)")

Total trajectory time: 3.00s (50 swing + 10 hold)


In [4]:
@jax.jit
def cart_pole_dynamics(state: jax.Array, force: jax.Array) -> jax.Array:
    """Compute state derivatives for the cart-pole system.

    State: [x, theta, x_dot, theta_dot]
    theta = 0: pole hanging down, theta = pi: pole upright

    Args:
        state: Current state [x, theta, x_dot, theta_dot]
        force: Control force [f]

    Returns:
        State derivatives [x_dot, theta_dot, x_ddot, theta_ddot]
    """
    x, theta, x_dot, theta_dot = state
    f = force[0]

    sin_th = jnp.sin(theta)
    cos_th = jnp.cos(theta)

    # Total mass.
    total_mass = m_cart + m_pole
    pole_mass_length = m_pole * length

    # Equations of motion for theta=0 being DOWN (stable equilibrium)
    temp = (f + pole_mass_length * theta_dot**2 * sin_th) / total_mass
    theta_ddot = (-gravity * sin_th - cos_th * temp) / (
        length * (4.0 / 3.0 - m_pole * cos_th**2 / total_mass)
    )
    x_ddot = temp - pole_mass_length * theta_ddot * cos_th / total_mass

    return jnp.array([x_dot, theta_dot, x_ddot, theta_ddot])


@jax.jit
def rk4_step(state: jax.Array, force: jax.Array, dt: float) -> jax.Array:
    """Runge-Kutta 4th order integration step.

    Args:
        state: Current state [x, theta, x_dot, theta_dot]
        force: Control force [f]
        dt: Time step

    Returns:
        Next state after integration
    """
    k1 = cart_pole_dynamics(state, force)
    k2 = cart_pole_dynamics(state + 0.5 * dt * k1, force)
    k3 = cart_pole_dynamics(state + 0.5 * dt * k2, force)
    k4 = cart_pole_dynamics(state + dt * k3, force)
    return state + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
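
As a rough check on the integrator, the sketch below applies the same RK4 update to a harmonic oscillator, where the exact solution is known, and compares the error at two step sizes. The oscillator, the helper names (`oscillator`, `rk4`, `integration_error`), and the step sizes are illustrative choices made here and are not part of the notebook.

```python
import jax.numpy as jnp


def oscillator(state):
    """Harmonic oscillator x'' = -x, written as a first-order system."""
    x, v = state
    return jnp.array([v, -x])


def rk4(f, state, dt):
    """One classical RK4 step for an autonomous system state' = f(state)."""
    k1 = f(state)
    k2 = f(state + 0.5 * dt * k1)
    k3 = f(state + 0.5 * dt * k2)
    k4 = f(state + dt * k3)
    return state + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)


def integration_error(dt, total_time=2.0):
    """Distance between the RK4 result and the exact solution at total_time."""
    state = jnp.array([1.0, 0.0])  # x(0) = 1, v(0) = 0
    for _ in range(round(total_time / dt)):
        state = rk4(oscillator, state, dt)
    exact = jnp.array([jnp.cos(total_time), -jnp.sin(total_time)])
    return float(jnp.linalg.norm(state - exact))


coarse = integration_error(0.5)
fine = integration_error(0.25)
# Halving dt should shrink the error by roughly 2**4 = 16 for a 4th-order method.
error_ratio = coarse / fine
print(f"error(dt=0.5)={coarse:.2e}, error(dt=0.25)={fine:.2e}, ratio={error_ratio:.1f}")
```

The step sizes are deliberately coarse so that the truncation error stays well above float32 rounding noise.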

## Forward simulation

The key difference from direct collocation: we simulate the entire trajectory forward using `jax.lax.scan`. Given a control sequence, we compute all states by integrating the dynamics.

In [5]:
# Initial and target states.
initial_state = jnp.array([0.0, 0.0, 0.0, 0.0])  # x=0, theta=0 (down), zero velocity
target_state = jnp.array([0.0, jnp.pi, 0.0, 0.0])  # x=0, theta=pi (up), zero velocity


@jax.jit
def simulate_trajectory(controls: jax.Array) -> jax.Array:
    """Simulate the cart-pole system forward given a control sequence.

    Args:
        controls: Control forces, shape (n_steps, 1)

    Returns:
        State trajectory, shape (n_steps + 1, 4)
    """

    def step(state: jax.Array, control: jax.Array) -> tuple[jax.Array, jax.Array]:
        next_state = rk4_step(state, control, dt)
        return next_state, next_state

    _, trajectory = jax.lax.scan(step, initial_state, controls)
    # Prepend initial state.
    return jnp.concatenate([initial_state[None, :], trajectory], axis=0)


# Test forward simulation with zero controls.
test_controls = jnp.zeros((n_steps, 1))
test_trajectory = simulate_trajectory(test_controls)
print(f"Trajectory shape: {test_trajectory.shape}")
print(f"Final state with zero controls: {test_trajectory[-1]}")

Trajectory shape: (61, 4)
Final state with zero controls: [0. 0. 0. 0.]
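
The scan pattern above can be seen in isolation on a smaller system. The sketch below rolls out a double integrator with explicit Euler steps; the system, the step size `toy_dt`, and the function name `euler_step` are hypothetical and chosen only to keep the numbers easy to follow.

```python
import jax
import jax.numpy as jnp

toy_dt = 0.1  # Hypothetical step size for this toy example.


def euler_step(state: jax.Array, accel: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Advance a double integrator [position, velocity] by one explicit Euler step."""
    position, velocity = state
    next_state = jnp.array([position + toy_dt * velocity, velocity + toy_dt * accel])
    # scan expects (new carry, per-step output); returning the state twice
    # records every intermediate state.
    return next_state, next_state


start = jnp.array([0.0, 0.0])
accels = jnp.ones(3)  # Constant unit acceleration for three steps.

final_state, visited = jax.lax.scan(euler_step, start, accels)
# scan only returns the states after each step, so the initial state is prepended.
rollout = jnp.concatenate([start[None, :], visited], axis=0)

print(rollout.shape)  # (4, 2): three steps plus the prepended initial state
print(rollout)
```

The carry threads the current state from one step to the next, while the stacked outputs give the full trajectory, which is why `simulate_trajectory` returns `n_steps + 1` rows for `n_steps` controls.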


## Control variable and costs

In the shooting method, we optimize a single variable containing the entire control trajectory. The costs include:
1. Terminal cost: Penalize deviation from upright position at final time
2. Hold cost: Penalize deviation from upright position during the hold phase
3. Control regularization: Penalize large control inputs

In [6]:
class ControlTrajectoryVar(
    jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((n_steps, 1))
):
    """Control trajectory variable: (n_steps, 1) array of forces."""


control_var = ControlTrajectoryVar(id=0)

In [7]:
@jaxls.Cost.factory
def terminal_cost(
    vals: jaxls.VarValues,
    var: ControlTrajectoryVar,
    target: jax.Array,
) -> jax.Array:
    """Penalize deviation from target state at final time.

    Forward simulate the trajectory, then compute error at final state.
    """
    controls = vals[var]
    trajectory = simulate_trajectory(controls)
    final_state = trajectory[-1]
    # Weight position and angle errors more heavily.
    weights = jnp.array([10.0, 20.0, 5.0, 5.0])
    return (final_state - target) * weights


@jaxls.Cost.factory
def hold_cost(
    vals: jaxls.VarValues,
    var: ControlTrajectoryVar,
    target: jax.Array,
) -> jax.Array:
    """Penalize deviation from upright during the hold phase."""
    controls = vals[var]
    trajectory = simulate_trajectory(controls)
    # Get states during hold phase (use static slice with captured n_swing)
    hold_states = jax.lax.slice_in_dim(trajectory, n_swing, n_steps + 1, axis=0)
    # Penalize deviation from target for each hold state.
    weights = jnp.array([5.0, 10.0, 2.0, 2.0])
    errors = (hold_states - target) * weights
    return errors.flatten()


@jaxls.Cost.factory
def control_regularization(
    vals: jaxls.VarValues,
    var: ControlTrajectoryVar,
) -> jax.Array:
    """Penalize control effort (L2 regularization)."""
    controls = vals[var]
    return controls.flatten() * 0.1
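
A quick arithmetic check on these weights, assuming the solver reports each cost as the sum of its squared residuals (an assumption made here, although the first step of the solver log further down is consistent with it). With near-zero controls the pole stays hanging, so the angle residual dominates both state costs.

```python
import math

n_swing, n_steps = 50, 60  # Same values as the trajectory parameters above.
angle_error = 0.0 - math.pi  # Hanging (theta = 0) versus upright (theta = pi).

# Terminal cost: a single final state, angle weight 20.
terminal_estimate = (20.0 * angle_error) ** 2

# Hold cost: states n_swing .. n_steps inclusive, angle weight 10.
n_hold_states = n_steps + 1 - n_swing
hold_estimate = n_hold_states * (10.0 * angle_error) ** 2

print(f"hold states: {n_hold_states}, residuals: {4 * n_hold_states}")
print(f"terminal ~ {terminal_estimate:.1f}, hold ~ {hold_estimate:.1f}")
```

These estimates are within about 0.3% of the 3938.88 and 10881.13 logged at step #1; the small gap presumably comes from the random initial controls, which are not exactly zero. The hold window contains 11 states although `n_hold = 10`, because the slice runs through index `n_steps` inclusive, so the final state is penalized by both the terminal cost and the hold cost.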

## Solving

In [8]:
costs: list[jaxls.Cost] = [
    terminal_cost(control_var, target_state),
    hold_cost(control_var, target_state),
    control_regularization(control_var),
]

print(f"Created {len(costs)} cost functions")
print(f"Control variable shape: ({n_steps}, 1)")

Created 3 cost functions
Control variable shape: (60, 1)


In [None]:
# Initial guess: small random controls to break symmetry.
key = jax.random.PRNGKey(42)
initial_controls = jax.random.normal(key, (n_steps, 1)) * 0.1

initial_vals = jaxls.VarValues.make([control_var.with_value(initial_controls)])

# Create the problem.
problem = jaxls.LeastSquaresProblem(costs, [control_var])

# Visualize the problem structure.
problem.show()

In [None]:
# Analyze the problem.
problem = problem.analyze()

In [10]:
solution = problem.solve(initial_vals)

INFO     |  step #1: cost=14820.0186 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - terminal_cost(1): 3938.88013 (avg 984.72003)
INFO     |      - hold_cost(1):   10881.13477 (avg 247.29852)
INFO     |      - control_regularization(1): 0.00345 (avg 0.00006)
INFO     |      accepted=True ATb_norm=5.20e+02 cost_prev=14820.0186 cost_new=10036.3008
INFO     |  step #2: cost=10036.2979 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - terminal_cost(1): 623.13776 (avg 155.78444)
INFO     |      - hold_cost(1):   9330.73438 (avg 212.06215)
INFO     |      - control_regularization(1): 82.42606 (avg 1.37377)
INFO     |  step #3: cost=10036.2979 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - terminal_cost(1): 623.13776 (avg 155.78444)
INFO     |      - hold_cost(1):   9330.73438 (avg 212.06215)
INFO     |      - control_regularization(1): 82.42606 (avg 1.37377)
INFO     |  ste

## Visualization

In [11]:
# Extract solution.
optimal_controls = solution[control_var]
optimal_trajectory = simulate_trajectory(optimal_controls)

times = jnp.linspace(0, total_time, n_steps + 1)
control_times = jnp.linspace(0, total_time - dt, n_steps)

print(
    f"Final state: x={float(optimal_trajectory[-1, 0]):.3f}, theta={float(optimal_trajectory[-1, 1]):.3f}"
)
print(f"Target theta (pi): {jnp.pi:.3f}")
print(f"Terminal angle error: {abs(float(optimal_trajectory[-1, 1]) - jnp.pi):.4f} rad")

Final state: x=-0.003, theta=3.153
Target theta (pi): 3.142
Terminal angle error: 0.0117 rad
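
The printed error compares `theta` with `pi` directly. That matches the costs used here, which target `theta = pi` without wrapping, but it would overstate the error for a trajectory that reached upright by swinging the other way (for example `theta` near `-pi`). A minimal sketch of a wrapped error, using a hypothetical helper `wrapped_angle_error` that is not part of the notebook:

```python
import math


def wrapped_angle_error(theta: float, target: float = math.pi) -> float:
    """Signed angular difference theta - target, wrapped into [-pi, pi]."""
    diff = theta - target
    return math.atan2(math.sin(diff), math.cos(diff))


# Same pole angle as the printed final state (rounded to three decimals).
print(f"{wrapped_angle_error(3.153):.4f}")
# Upright reached in the other direction: a full turn away from the target.
print(f"{wrapped_angle_error(-math.pi):.4f}")
```

For small differences the wrapped and unwrapped errors agree, so the diagnostic above is unaffected for this solution.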


In [12]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

# Animation of cart-pole swing-up.
cart_y_offset = 1.2  # Raise cart so pole can swing freely above ground.
wheel_radius = 0.05


def create_cart_pole_frame(x: float, theta: float) -> dict:
    """Create a single frame of cart-pole visualization.

    Args:
        x: Cart position.
        theta: Pole angle (0 = down, pi = up).

    Returns:
        Dictionary with cart, pole, mass, and wheel coordinates.
    """
    cart_width, cart_height = 0.4, 0.15
    pole_len = 2 * length  # Full pole length.
    pole_x = x + pole_len * jnp.sin(theta)
    pole_y = cart_y_offset - pole_len * jnp.cos(theta)

    # Wheel positions.
    wheel_y = cart_y_offset - cart_height / 2 - wheel_radius

    return dict(
        cart_x=[
            x - cart_width / 2,
            x + cart_width / 2,
            x + cart_width / 2,
            x - cart_width / 2,
            x - cart_width / 2,
        ],
        cart_y=[
            cart_y_offset - cart_height / 2,
            cart_y_offset - cart_height / 2,
            cart_y_offset + cart_height / 2,
            cart_y_offset + cart_height / 2,
            cart_y_offset - cart_height / 2,
        ],
        pole_x=[x, float(pole_x)],
        pole_y=[cart_y_offset, float(pole_y)],
        mass_x=[float(pole_x)],
        mass_y=[float(pole_y)],
        wheel1_x=[x - cart_width / 3],
        wheel1_y=[wheel_y],
        wheel2_x=[x + cart_width / 3],
        wheel2_y=[wheel_y],
    )


# Create animation frames (use every 2nd timestep for performance).
frame_step = 2
frame_indices = list(range(0, n_steps + 1, frame_step))
frames = []
for i in frame_indices:
    frame_data = create_cart_pole_frame(
        float(optimal_trajectory[i, 0]), float(optimal_trajectory[i, 1])
    )
    frames.append(
        go.Frame(
            data=[
                go.Scatter(
                    x=frame_data["cart_x"],
                    y=frame_data["cart_y"],
                    fill="toself",
                    fillcolor="steelblue",
                    line=dict(color="darkblue", width=2),
                ),
                go.Scatter(
                    x=frame_data["pole_x"],
                    y=frame_data["pole_y"],
                    mode="lines",
                    line=dict(color="sienna", width=8),
                ),
                go.Scatter(
                    x=frame_data["mass_x"],
                    y=frame_data["mass_y"],
                    mode="markers",
                    marker=dict(
                        size=18, color="sienna", line=dict(color="black", width=1)
                    ),
                ),
                go.Scatter(
                    x=frame_data["wheel1_x"],
                    y=frame_data["wheel1_y"],
                    mode="markers",
                    marker=dict(size=12, color="black", symbol="circle"),
                ),
                go.Scatter(
                    x=frame_data["wheel2_x"],
                    y=frame_data["wheel2_y"],
                    mode="markers",
                    marker=dict(size=12, color="black", symbol="circle"),
                ),
            ],
            name=str(i),
        )
    )

# Initial frame.
init_frame = create_cart_pole_frame(
    float(optimal_trajectory[0, 0]), float(optimal_trajectory[0, 1])
)
rail_y = cart_y_offset - 0.15 / 2 - wheel_radius * 2

fig_anim = go.Figure(
    data=[
        go.Scatter(
            x=init_frame["cart_x"],
            y=init_frame["cart_y"],
            fill="toself",
            fillcolor="steelblue",
            line=dict(color="darkblue", width=2),
            name="Cart",
        ),
        go.Scatter(
            x=init_frame["pole_x"],
            y=init_frame["pole_y"],
            mode="lines",
            line=dict(color="sienna", width=8),
            name="Pole",
        ),
        go.Scatter(
            x=init_frame["mass_x"],
            y=init_frame["mass_y"],
            mode="markers",
            marker=dict(size=18, color="sienna", line=dict(color="black", width=1)),
            showlegend=False,
        ),
        go.Scatter(
            x=init_frame["wheel1_x"],
            y=init_frame["wheel1_y"],
            mode="markers",
            marker=dict(size=12, color="black", symbol="circle"),
            showlegend=False,
        ),
        go.Scatter(
            x=init_frame["wheel2_x"],
            y=init_frame["wheel2_y"],
            mode="markers",
            marker=dict(size=12, color="black", symbol="circle"),
            showlegend=False,
        ),
        go.Scatter(
            x=[-2.5, 2.5],
            y=[rail_y, rail_y],
            mode="lines",
            line=dict(color="gray", width=6),
            name="Rail",
        ),
    ],
    frames=frames,
    layout=go.Layout(
        title="Cart-Pole (Shooting)",
        xaxis=dict(
            range=[-2.5, 2.5], title="x (m)", constrain="domain", showgrid=False
        ),
        yaxis=dict(
            range=[-0.1, 2.5],
            title="y (m)",
            scaleanchor="x",
            scaleratio=1,
            showgrid=False,
        ),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=1.15,
                x=0.5,
                xanchor="center",
                buttons=[
                    dict(
                        label="Play",
                        method="animate",
                        args=[
                            None,
                            dict(
                                frame=dict(duration=50, redraw=True),
                                fromcurrent=True,
                                transition=dict(duration=0),
                            ),
                        ],
                    ),
                    dict(
                        label="Pause",
                        method="animate",
                        args=[
                            [None],
                            dict(
                                frame=dict(duration=0, redraw=False),
                                mode="immediate",
                            ),
                        ],
                    ),
                ],
            )
        ],
        sliders=[
            dict(
                active=0,
                yanchor="top",
                xanchor="left",
                currentvalue=dict(
                    prefix="Time: ",
                    suffix="s",
                    visible=True,
                    xanchor="center",
                    offset=20,
                ),
                pad=dict(b=10, t=60),
                steps=[
                    dict(
                        args=[
                            [str(i)],
                            dict(
                                frame=dict(duration=0, redraw=True),
                                mode="immediate",
                                transition=dict(duration=0),
                            ),
                        ],
                        label=f"{float(times[i]):.2f}",
                        method="animate",
                    )
                    for i in frame_indices
                ],
                x=0.1,
                y=0,
                len=0.8,
            )
        ],
        height=500,
        showlegend=False,
        margin=dict(t=80, b=100),
        plot_bgcolor="white",
    ),
)
HTML(fig_anim.to_html(full_html=False, include_plotlyjs="cdn", auto_play=False))

In [13]:
# State and control trajectories.
fig_traj = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=("Cart Position", "Pole Angle", "Velocities", "Control Force"),
    vertical_spacing=0.15,
    horizontal_spacing=0.1,
)

# Cart position.
fig_traj.add_trace(
    go.Scatter(
        x=[float(t) for t in times],
        y=[float(s) for s in optimal_trajectory[:, 0]],
        mode="lines",
        line=dict(color="steelblue", width=2),
        name="x",
    ),
    row=1,
    col=1,
)

# Pole angle.
fig_traj.add_trace(
    go.Scatter(
        x=[float(t) for t in times],
        y=[float(s) for s in optimal_trajectory[:, 1]],
        mode="lines",
        line=dict(color="coral", width=2),
        name="theta",
    ),
    row=1,
    col=2,
)
fig_traj.add_hline(
    y=float(jnp.pi),
    line_dash="dash",
    line_color="gray",
    row=1,
    col=2,
    annotation_text="pi",
)

# Velocities.
fig_traj.add_trace(
    go.Scatter(
        x=[float(t) for t in times],
        y=[float(s) for s in optimal_trajectory[:, 2]],
        mode="lines",
        line=dict(color="steelblue", width=2),
        name="x_dot",
    ),
    row=2,
    col=1,
)
fig_traj.add_trace(
    go.Scatter(
        x=[float(t) for t in times],
        y=[float(s) for s in optimal_trajectory[:, 3]],
        mode="lines",
        line=dict(color="coral", width=2, dash="dash"),
        name="theta_dot",
    ),
    row=2,
    col=1,
)

# Control force.
fig_traj.add_trace(
    go.Scatter(
        x=[float(t) for t in control_times],
        y=[float(u) for u in optimal_controls[:, 0]],
        mode="lines",
        line=dict(color="forestgreen", width=2),
        name="F",
        fill="tozeroy",
        fillcolor="rgba(34, 139, 34, 0.2)",
    ),
    row=2,
    col=2,
)

# Axis labels.
fig_traj.update_xaxes(title_text="Time (s)", row=1, col=1)
fig_traj.update_yaxes(title_text="Position (m)", row=1, col=1)
fig_traj.update_xaxes(title_text="Time (s)", row=1, col=2)
fig_traj.update_yaxes(title_text="Angle (rad)", row=1, col=2)
fig_traj.update_xaxes(title_text="Time (s)", row=2, col=1)
fig_traj.update_yaxes(title_text="Velocity (m/s, rad/s)", row=2, col=1)
fig_traj.update_xaxes(title_text="Time (s)", row=2, col=2)
fig_traj.update_yaxes(title_text="Force (N)", row=2, col=2)

fig_traj.update_layout(height=500, showlegend=False, margin=dict(t=40, b=40))
HTML(fig_traj.to_html(full_html=False, include_plotlyjs="cdn"))

## Comparison with direct collocation

The shooting method optimizes only the control trajectory (60 decision variables for 60 timesteps), while direct collocation in {doc}`cart_pole_collocation` optimizes both states and controls (roughly 300 decision variables: 61 states × 4 = 244 state values, plus the control sequence).

Shooting method advantages:
- Fewer decision variables
- Dynamics are satisfied by construction (no defect constraints)
- Simpler problem formulation

Shooting method disadvantages:
- Sensitivity of the final state to early controls can grow exponentially with the time horizon
- Harder to add state constraints (must be handled through cost)
- Can struggle with long horizons or unstable dynamics
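
The sensitivity growth listed above can be illustrated on a toy problem. The sketch below uses a hypothetical unstable scalar system, `x_next = growth * x + u`, which is not the cart-pole model, and differentiates the final state with respect to every control in the sequence.

```python
import jax
import jax.numpy as jnp

growth = 1.5  # Hypothetical unstable scalar system: x_next = growth * x + u.


def final_state(controls: jax.Array) -> jax.Array:
    """Roll out the scalar system from x = 0 and return the last state."""

    def step(x: jax.Array, u: jax.Array) -> tuple[jax.Array, None]:
        return growth * x + u, None

    x_final, _ = jax.lax.scan(step, jnp.zeros(()), controls)
    return x_final


for horizon in (10, 20, 40):
    # d(final state) / d(control k) for every k, via reverse-mode autodiff.
    sensitivity = jax.grad(final_state)(jnp.zeros(horizon))
    print(horizon, float(sensitivity[0]), float(sensitivity[-1]))
```

The effect of the first control on the final state grows as `growth ** (horizon - 1)`, while the last control always has unit effect. This widening spread in Jacobian column scales is one way long-horizon shooting problems become poorly conditioned.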

For this cart-pole problem, both methods work well because the time horizon is short. For longer horizons or more complex systems, direct collocation often provides better convergence.

For solver configuration options, see {class}`jaxls.TrustRegionConfig` and {class}`jaxls.TerminationConfig`.