Inverse kinematics#

In this notebook, we solve an inverse kinematics problem: finding joint angles for a 7-DOF Franka Panda arm to reach target poses.

Note

This example demonstrates basic IK with jaxls. For kinematic optimization in practice, see PyRoki, which provides a more complete interface for robot kinematics problems.

Features used:

  • Var for the joint angle variable

  • @jaxls.Cost.factory for end-effector constraints and costs

  • Equality constraints (constraint_eq_zero): position and orientation targets

  • Inequality constraints (constraint_leq_zero): joint limit enforcement

  • Augmented Lagrangian solver for constrained optimization

  • Forward kinematics with URDF parsing

Hide code cell source

import sys
from loguru import logger

# Replace loguru's default handler with a single stdout handler using a compact
# "LEVEL    | message" format, so solver log lines appear inline with the cell
# output instead of being emitted twice. The trailing ';' keeps Jupyter from
# echoing the value returned by logger.add.
logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls
import yourdfpy
from robot_descriptions.panda_description import URDF_PATH

Robot model#

Load the Franka Panda URDF and extract joint parameters:

# Load the Panda model. The collision scene graph and meshes are also loaded
# because the visualization section renders the collision geometry.
urdf = yourdfpy.URDF.load(
    URDF_PATH,
    load_collision_meshes=True,
    build_collision_scene_graph=True,
)

# Actuated arm joints are named panda_joint1 ... panda_joint7 in the URDF.
joint_names = [f"panda_joint{k}" for k in range(1, 8)]
joint_limits_lower = jnp.array(
    [urdf.joint_map[name].limit.lower for name in joint_names]
)
joint_limits_upper = jnp.array(
    [urdf.joint_map[name].limit.upper for name in joint_names]
)

print("Franka Panda: 7-DOF robot arm")
print("Joint limits (rad):")
for idx, (lo, hi) in enumerate(zip(joint_limits_lower, joint_limits_upper), start=1):
    print(f"  Joint {idx}: [{float(lo):.2f}, {float(hi):.2f}]")
Franka Panda: 7-DOF robot arm
Joint limits (rad):
  Joint 1: [-2.90, 2.90]
  Joint 2: [-1.76, 1.76]
  Joint 3: [-2.90, 2.90]
  Joint 4: [-3.07, -0.07]
  Joint 5: [-2.90, 2.90]
  Joint 6: [-0.02, 3.75]
  Joint 7: [-2.90, 2.90]

Forward kinematics#

Build the kinematic chain from URDF transforms. Each joint applies a rotation about its local z-axis:

Hide code cell source

# Fixed (non-actuated) part of each joint: its URDF origin frame.
def get_joint_transform(joint_name: str) -> tuple[jax.Array, jax.Array]:
    """Return the rotation and translation of a joint's URDF origin.

    Args:
        joint_name: Key into ``urdf.joint_map``.

    Returns:
        ``(R, t)``: the 3x3 rotation block and length-3 translation of the
        joint's 4x4 origin transform. A joint without an origin is treated
        as the identity transform.
    """
    origin = urdf.joint_map[joint_name].origin
    if origin is None:
        origin = jnp.eye(4)
    rotation = jnp.array(origin[:3, :3])
    translation = jnp.array(origin[:3, 3])
    return rotation, translation


# Fixed (rotation, translation) pair for every joint origin, computed once so
# the FK functions below can close over them as constants.
joint_transforms = list(map(get_joint_transform, joint_names))

# Flange offset from link7 to link8, expressed in the link7 frame (meters).
ee_offset = jnp.array([0.0, 0.0, 0.107])


@jax.jit
def rotation_z(theta: jax.Array) -> jax.Array:
    """Return the 3x3 matrix for a rotation of ``theta`` radians about +z."""
    cos_t = jnp.cos(theta)
    sin_t = jnp.sin(theta)
    return jnp.array(
        [
            [cos_t, -sin_t, 0.0],
            [sin_t, cos_t, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )


@jax.jit
def forward_kinematics(q: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Map joint angles to the end-effector (flange) pose.

    Args:
        q: Joint angles (7,)

    Returns:
        (position, rotation_matrix): End-effector pose in base frame
    """
    rot = jnp.eye(3)
    pos = jnp.zeros(3)

    # Walk the chain: shift by the joint's fixed offset (expressed in the
    # current frame), then compose its fixed origin rotation with the joint
    # rotation about the local z-axis.
    for idx, (rot_fixed, trans_fixed) in enumerate(joint_transforms):
        pos = pos + rot @ trans_fixed
        rot = rot @ rot_fixed @ rotation_z(q[idx])

    # Finally step out to the flange along the last link's frame.
    ee_pos = pos + rot @ ee_offset
    return ee_pos, rot


@jax.jit
def forward_kinematics_all_links(q: jax.Array) -> jax.Array:
    """Return the origin of every link along the chain, for plotting.

    Args:
        q: Joint angles (7,)

    Returns:
        positions: (9, 3) array of link positions (base + 7 joints + end-effector)
    """
    rot = jnp.eye(3)
    pos = jnp.zeros(3)
    points = [jnp.zeros(3)]  # base origin

    # Same chain walk as forward_kinematics, recording each joint origin.
    for idx, (rot_fixed, trans_fixed) in enumerate(joint_transforms):
        pos = pos + rot @ trans_fixed
        rot = rot @ rot_fixed @ rotation_z(q[idx])
        points.append(pos)

    # Flange position closes out the list.
    points.append(pos + rot @ ee_offset)
    return jnp.stack(points)


# Sanity check: evaluate FK at the all-zeros joint configuration.
q_test = jnp.zeros((7,))
pos, rot = forward_kinematics(q=q_test)
print(f"End-effector at q=0: position={pos}")
End-effector at q=0: position=[ 8.800000e-02 -8.939922e-18  9.260000e-01]

Variables and costs#

Define the joint angle variable and the IK costs and constraints:

class JointAnglesVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((7,))):
    """Joint configuration of the 7-DOF arm; defaults to all zeros."""


# The single joint-configuration variable optimized by every IK problem below.
joint_var = JointAnglesVar(id=0)
@jaxls.Cost.factory(kind="constraint_eq_zero")
def position_constraint(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    target_pos: jax.Array,
) -> jax.Array:
    """Hard equality: the end-effector position must equal ``target_pos``.

    Returns the (3,) position residual, which the solver drives to zero.
    """
    ee_pos = forward_kinematics(vals[var])[0]
    return ee_pos - target_pos


@jaxls.Cost.factory(kind="constraint_eq_zero")
def orientation_constraint(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    target_z_axis: jax.Array,
) -> jax.Array:
    """Hard equality: the end-effector z-axis must equal ``target_z_axis``.

    Only the z-axis direction is constrained, so rotation about that axis
    remains free.
    """
    rotation = forward_kinematics(vals[var])[1]
    # Third column of the rotation matrix = end-effector z-axis in base frame.
    ee_z_axis = rotation[:, 2]
    return ee_z_axis - target_z_axis


@jaxls.Cost.factory
def regularization_cost(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    q_ref: jax.Array,
) -> jax.Array:
    """Soft cost pulling the joints toward ``q_ref``.

    The arm is redundant for the position + z-axis task, so this term picks
    a preferred solution. The residual is scaled by 0.1 to keep it weak.
    """
    deviation = vals[var] - q_ref
    return 0.1 * deviation


@jaxls.Cost.factory(kind="constraint_leq_zero")
def joint_upper_limit(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    limits: jax.Array,
) -> jax.Array:
    """Inequality q <= limits, written as q - limits <= 0 (elementwise)."""
    q = vals[var]
    return q - limits


@jaxls.Cost.factory(kind="constraint_leq_zero")
def joint_lower_limit(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    limits: jax.Array,
) -> jax.Array:
    """Inequality q >= limits, written as limits - q <= 0 (elementwise)."""
    q = vals[var]
    return limits - q

Problem structure#

The IK problem consists of position and orientation constraints, a regularization cost for redundancy resolution, and joint limit constraints. The structure of this problem can be visualized as follows:

# Build an example problem (one target, end-effector z-axis pointing down)
# purely to visualize its structure; it is not solved here.
example_target = jnp.array([0.4, 0.0, 0.4])
example_costs = [
    position_constraint(joint_var, example_target),
    orientation_constraint(joint_var, jnp.array([0.0, 0.0, -1.0])),
    regularization_cost(joint_var, jnp.zeros(7)),
    joint_upper_limit(joint_var, joint_limits_upper),
    joint_lower_limit(joint_var, joint_limits_lower),
]

# Visualize the problem structure.
jaxls.LeastSquaresProblem(example_costs, [joint_var]).show()

Solve IK for multiple targets#

We’ll compute IK solutions for a sequence of target positions along a circular path:

# Define target trajectory: circle in the x-z plane.
n_targets = 16
center = jnp.array([0.4, 0.0, 0.4])
radius = 0.15
angles = jnp.linspace(0, 2 * jnp.pi, n_targets, endpoint=False)
targets = jnp.stack(
    [
        center[0] + radius * jnp.cos(angles),
        jnp.zeros(n_targets),
        center[2] + radius * jnp.sin(angles),
    ],
    axis=-1,
)

# Desired end-effector orientation: pointing down.
target_z = jnp.array([0.0, 0.0, -1.0])

# Reference configuration (middle of joint ranges)
q_ref = (joint_limits_lower + joint_limits_upper) / 2

print(f"Solving IK for {n_targets} targets along circular path")
print(f"Circle center: {center}, radius: {radius}")
Solving IK for 16 targets along circular path
Circle center: [0.4 0.  0.4], radius: 0.15
@jax.jit
def solve_ik(target_pos: jax.Array, q_init: jax.Array) -> jax.Array:
    """Solve one IK problem, warm-started from ``q_init``.

    The end-effector must reach ``target_pos`` with its z-axis along the
    global ``target_z``. Joint limits are enforced as inequality constraints,
    and a soft cost pulls the joints toward ``q_ref``.

    Args:
        target_pos: Target end-effector position (3,)
        q_init: Initial joint angles guess (7,)

    Returns:
        Optimized joint angles (7,)
    """
    equality_terms = [
        position_constraint(joint_var, target_pos),
        orientation_constraint(joint_var, target_z),
    ]
    soft_terms = [regularization_cost(joint_var, q_ref)]
    limit_terms = [
        joint_upper_limit(joint_var, joint_limits_upper),
        joint_lower_limit(joint_var, joint_limits_lower),
    ]
    problem = jaxls.LeastSquaresProblem(
        equality_terms + soft_terms + limit_terms, [joint_var]
    ).analyze()

    initial_vals = jaxls.VarValues.make([joint_var.with_value(q_init)])
    solution = problem.solve(
        initial_vals=initial_vals,
        verbose=False,
        termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    )
    return solution[joint_var]


# Solve the targets in order, warm-starting each solve from the previous
# solution so consecutive configurations stay close to each other.
q_current = q_ref
solutions = []
for target in targets:
    q_current = solve_ik(target, q_current)
    solutions.append(q_current)

solutions = jnp.stack(solutions)
print(f"Solved {n_targets} IK problems")

# Verify: Euclidean distance between each solution's FK position and its target.
position_errors = [
    float(jnp.linalg.norm(forward_kinematics(q_sol)[0] - target))
    for q_sol, target in zip(solutions, targets)
]

mean_error_mm = jnp.mean(jnp.array(position_errors)) * 1000
max_error_mm = max(position_errors) * 1000
print(f"Position errors: mean={mean_error_mm:.2f}mm, max={max_error_mm:.2f}mm")
INFO     | Building optimization problem with 5 terms and 1 variables: 1 costs, 2 eq_zero, 2 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_position_constraint
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_orientation_constraint
INFO     | Vectorizing group with 1 costs, 1 variables each: regularization_cost
INFO     | Vectorizing constraint group with 1 constraints (constraint_leq_zero), 1 variables each: augmented_joint_upper_limit
INFO     | Vectorizing constraint group with 1 constraints (constraint_leq_zero), 1 variables each: augmented_joint_lower_limit
Solved 16 IK problems
Position errors: mean=0.00mm, max=0.01mm

Visualization#

3D visualization of the robot arm at different poses along the target trajectory:

Hide code cell source

import contextlib
import io
import numpy as np
import trimesh
import viser

# Create Viser server. stdout/stderr are redirected so anything printed while
# constructing the server is discarded.
with (
    contextlib.redirect_stdout(io.StringIO()),
    contextlib.redirect_stderr(io.StringIO()),
):
    server = viser.ViserServer(verbose=False)

server.scene.set_up_direction("+z")

# Set initial camera position for a good view of the robot.
server.initial_camera.position = (1.5, -1.0, 0.8)
server.initial_camera.look_at = (0.3, 0.0, 0.4)

# Add ground grid.
server.scene.add_grid(
    "/grid",
    width=2.0,
    height=2.0,
    plane="xy",
    position=(0.3, 0, 0),
    cell_color=(200, 200, 200),
    section_color=(170, 170, 170),
)

# Colors for different robot poses (soft, distinguishable palette).
pose_colors = [
    (70, 130, 180),  # Steel blue
    (60, 179, 113),  # Medium sea green
    (255, 165, 0),  # Orange
    (147, 112, 219),  # Medium purple
]

# Show robot at selected poses along trajectory.
display_indices = [0, 5, 10, 15]

for pose_idx, sol_idx in enumerate(display_indices):
    q = solutions[sol_idx]
    color = pose_colors[pose_idx % len(pose_colors)]

    # Pose the URDF model at this solution so the collision scene graph
    # reflects the new joint angles.
    cfg = {joint_names[i]: float(q[i]) for i in range(7)}
    urdf.update_cfg(cfg)

    # Get collision scene meshes.
    scene = urdf.collision_scene

    for geom_name, geom in scene.geometry.items():
        if isinstance(geom, trimesh.Trimesh):
            # Bake the world transform into a copy of the mesh vertices.
            transform = scene.graph.get(geom_name)[0]
            mesh = geom.copy()
            mesh.apply_transform(transform)
            mesh.fix_normals()

            server.scene.add_mesh_simple(
                f"/robot_{sol_idx}/{geom_name}",
                vertices=np.array(mesh.vertices),
                faces=np.array(mesh.faces),
                color=color,
                opacity=0.7,
            )

# Add target trajectory as points.
for i, target in enumerate(targets):
    server.scene.add_icosphere(
        f"/targets/{i}",
        radius=0.015,
        position=tuple(float(x) for x in target),
        color=(220, 60, 60),
    )

# Add the achieved end-effector positions: one small sphere per IK solution.
# (Previously the loop stopped one short and never drew the final solution.)
ee_positions = jax.vmap(lambda q: forward_kinematics(q)[0])(solutions)
for i, ee_pos in enumerate(ee_positions):
    server.scene.add_icosphere(
        f"/ee_traj/{i}",
        radius=0.008,
        position=tuple(float(x) for x in ee_pos),
        color=(100, 150, 220),
    )

server.scene.show(height=500)
╭────── viser (listening *:8082) ───────╮
│             ╷                         │
│   HTTP      │ http://localhost:8082   │
│   Websocket │ ws://localhost:8082     │
│             ╵                         │
╰───────────────────────────────────────╯

Joint trajectories#

Joint angles along the circular path. The dotted lines mark each joint's limits, showing that the limit constraints are satisfied:

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

fig = make_subplots(
    rows=2,
    cols=1,
    subplot_titles=("Joint Angles Along Path", "End-Effector Position Error"),
    vertical_spacing=0.22,
)

target_indices = list(range(n_targets))

# Joint angles, one trace per joint.
colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2"]
for j in range(7):
    fig.add_trace(
        go.Scatter(
            x=target_indices,
            y=[float(v) for v in solutions[:, j]],
            mode="lines+markers",
            name=f"Joint {j + 1}",
            line=dict(color=colors[j]),
            marker=dict(size=5),
            hovertemplate=f"Joint {j + 1}: %{{y:.3f}} rad<extra></extra>",
        ),
        row=1,
        col=1,
    )

# Joint limits as dotted horizontal lines (upper then lower), colored to
# match each joint's trace.
for j in range(7):
    for limit in (joint_limits_upper[j], joint_limits_lower[j]):
        fig.add_hline(
            y=float(limit),
            line=dict(color=colors[j], width=1, dash="dot"),
            opacity=0.5,
            row=1,
            col=1,
        )

# Position error, converted from meters to millimeters.
fig.add_trace(
    go.Scatter(
        x=target_indices,
        y=[e * 1000 for e in position_errors],
        mode="lines+markers",
        name="Position Error",
        line=dict(color="crimson"),
        marker=dict(size=6),
        hovertemplate="Error: %{y:.2f} mm<extra></extra>",
        showlegend=False,
    ),
    row=2,
    col=1,
)

fig.update_xaxes(title_text="Target Index", row=1, col=1)
fig.update_yaxes(title_text="Angle (rad)", row=1, col=1)
fig.update_xaxes(title_text="Target Index", row=2, col=1)
fig.update_yaxes(title_text="Error (mm)", row=2, col=1)

fig.update_layout(
    height=500,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=-0.2, xanchor="center", x=0.5),
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

The solver finds smooth joint trajectories that satisfy the joint limits (dotted lines) while tracking the circular end-effector path with sub-millimeter accuracy. The redundancy in the 7-DOF arm is resolved by the regularization cost that keeps joints near their reference values.

For solver configuration options, see jaxls.TrustRegionConfig and jaxls.AugmentedLagrangianConfig.