Inverse kinematics
In this notebook, we solve an inverse kinematics problem: finding joint angles for a 7-DOF Franka Panda arm to reach target poses.
Note
This example demonstrates basic IK with jaxls. For kinematic optimization in practice, see PyRoki, which provides a more complete interface for robot kinematics problems.
Features used:
Var for joint angle variables with batched IDs
@jaxls.Cost.factory for end-effector constraints and costs
Equality constraints (constraint_eq_zero): position and orientation targets
Inequality constraints (constraint_leq_zero): joint limit enforcement
Augmented Lagrangian solver for constrained optimization
Forward kinematics with URDF parsing
import jax
import jax.numpy as jnp
import jaxls
import yourdfpy
from robot_descriptions.panda_description import URDF_PATH
Robot model
Load the Franka Panda URDF and extract joint parameters:
urdf = yourdfpy.URDF.load(
    URDF_PATH,
    build_collision_scene_graph=True,
    load_collision_meshes=True,
)
# Extract joint information from URDF
joint_names = [f"panda_joint{i + 1}" for i in range(7)]
joint_limits_lower = jnp.array([urdf.joint_map[j].limit.lower for j in joint_names])
joint_limits_upper = jnp.array([urdf.joint_map[j].limit.upper for j in joint_names])
print("Franka Panda: 7-DOF robot arm")
print("Joint limits (rad):")
for i, (lo, hi) in enumerate(zip(joint_limits_lower, joint_limits_upper)):
    print(f"  Joint {i + 1}: [{float(lo):.2f}, {float(hi):.2f}]")
Franka Panda: 7-DOF robot arm
Joint limits (rad):
Joint 1: [-2.90, 2.90]
Joint 2: [-1.76, 1.76]
Joint 3: [-2.90, 2.90]
Joint 4: [-3.07, -0.07]
Joint 5: [-2.90, 2.90]
Joint 6: [-0.02, 3.75]
Joint 7: [-2.90, 2.90]
Forward kinematics
Build the kinematic chain from URDF transforms. Each joint applies a rotation about its local z-axis:
End-effector at q=0: position=[ 8.800000e-02 -8.939922e-18 9.260000e-01]
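As a rough illustration of the idea described above (fixed link transforms composed with a rotation about each joint's local z-axis), here is a minimal NumPy sketch for a hypothetical planar two-link chain. The name `forward_kinematics_sketch`, the `link_offsets` values, and the tool offset are assumptions made for this example; they are not the Panda's URDF data or the notebook's `forward_kinematics` function. A real URDF joint origin also carries a rotation, which this sketch omits by using translations only.

```python
import numpy as np


def rot_z(theta: float) -> np.ndarray:
    """Homogeneous 4x4 transform for a rotation by theta about the z-axis."""
    c, s = np.cos(theta), np.sin(theta)
    T = np.eye(4)
    T[:2, :2] = [[c, -s], [s, c]]
    return T


def translate(offset) -> np.ndarray:
    """Homogeneous 4x4 transform for a pure translation."""
    T = np.eye(4)
    T[:3, 3] = offset
    return T


def forward_kinematics_sketch(q, link_offsets, tool_offset):
    """Compose fixed joint-origin transforms with per-joint z-rotations.

    Returns (position, rotation) of the end-effector in the base frame.
    """
    T = np.eye(4)
    for theta, offset in zip(q, link_offsets):
        # Fixed transform from the parent link to the joint frame,
        # followed by the joint's rotation about its local z-axis.
        T = T @ translate(offset) @ rot_z(theta)
    T = T @ translate(tool_offset)
    return T[:3, 3], T[:3, :3]


# Hypothetical planar arm: two joints, link lengths 0.3 m and 0.2 m along x.
link_offsets = [np.zeros(3), np.array([0.3, 0.0, 0.0])]
tool_offset = np.array([0.2, 0.0, 0.0])

pos_zero, R_zero = forward_kinematics_sketch([0.0, 0.0], link_offsets, tool_offset)
pos_bent, R_bent = forward_kinematics_sketch([np.pi / 2, 0.0], link_offsets, tool_offset)
print(pos_zero)  # arm stretched along +x
print(pos_bent)  # first joint rotated by 90 degrees, arm points along +y
```

The function returns a position and a rotation matrix, the same pair the cost functions in the next section unpack as `pos, R`.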
Variables and costs
Define joint angle variables and cost functions for IK:
class JointAnglesVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(7)):
    """7-DOF joint configuration."""


joint_var = JointAnglesVar(id=0)


@jaxls.Cost.factory(kind="constraint_eq_zero")
def position_constraint(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    target_pos: jax.Array,
) -> jax.Array:
    """End-effector position constraint (hard equality)."""
    pos, _ = forward_kinematics(vals[var])
    return pos - target_pos


@jaxls.Cost.factory(kind="constraint_eq_zero")
def orientation_constraint(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    target_z_axis: jax.Array,
) -> jax.Array:
    """End-effector orientation constraint (align z-axis with target)."""
    _, R = forward_kinematics(vals[var])
    z_axis = R[:, 2]  # End-effector z-axis
    return z_axis - target_z_axis


@jaxls.Cost.factory
def regularization_cost(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    q_ref: jax.Array,
) -> jax.Array:
    """Prefer configurations close to reference (resolves redundancy)."""
    return (vals[var] - q_ref) * 0.1


@jaxls.Cost.factory(kind="constraint_leq_zero")
def joint_upper_limit(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    limits: jax.Array,
) -> jax.Array:
    """Joint angles <= upper limits."""
    return vals[var] - limits


@jaxls.Cost.factory(kind="constraint_leq_zero")
def joint_lower_limit(
    vals: jaxls.VarValues,
    var: JointAnglesVar,
    limits: jax.Array,
) -> jax.Array:
    """Joint angles >= lower limits (as -q + lower <= 0)."""
    return limits - vals[var]
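To make the sign convention of the two inequality costs concrete, the following NumPy sketch evaluates the same residuals outside of jaxls. It assumes, as the `constraint_leq_zero` name suggests, that a constraint counts as satisfied when every residual entry is at most zero. The limit values are the rounded ones printed earlier; the test configuration `q` and the helper name `limit_residuals` are made up for illustration.

```python
import numpy as np

# Rounded joint limits from the printout above (rad).
lower = np.array([-2.90, -1.76, -2.90, -3.07, -2.90, -0.02, -2.90])
upper = np.array([2.90, 1.76, 2.90, -0.07, 2.90, 3.75, 2.90])


def limit_residuals(q: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return (upper, lower) residuals; feasible when all entries are <= 0."""
    return q - upper, lower - q


# Hypothetical configuration: all joints at zero.
q = np.zeros(7)
upper_res, lower_res = limit_residuals(q)

# A positive residual entry marks a violated limit.
violated = np.flatnonzero((upper_res > 0) | (lower_res > 0))
max_violation = max(upper_res.max(), lower_res.max())
print("violated joints (0-based):", violated)
print("largest violation (rad):", max_violation)
```

With the printed limits, the all-zero configuration exceeds joint 4's upper bound of -0.07 rad, so only that entry of the upper residual is positive.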
Problem structure
The IK problem consists of position and orientation constraints, a regularization cost for redundancy resolution, and joint limit constraints. We can visualize its structure:
# Build an example problem to visualize.
example_target = jnp.array([0.4, 0.0, 0.4])
example_costs = [
    position_constraint(joint_var, example_target),
    orientation_constraint(joint_var, jnp.array([0.0, 0.0, -1.0])),
    regularization_cost(joint_var, jnp.zeros(7)),
    joint_upper_limit(joint_var, joint_limits_upper),
    joint_lower_limit(joint_var, joint_limits_lower),
]
# Visualize the problem structure.
jaxls.LeastSquaresProblem(example_costs, [joint_var]).show()
Solve IK for multiple targets
We’ll compute IK solutions for a sequence of target positions along a circular path:
# Define target trajectory: circle in the x-z plane.
n_targets = 16
center = jnp.array([0.4, 0.0, 0.4])
radius = 0.15
angles = jnp.linspace(0, 2 * jnp.pi, n_targets, endpoint=False)
targets = jnp.stack(
    [
        center[0] + radius * jnp.cos(angles),
        jnp.zeros(n_targets),
        center[2] + radius * jnp.sin(angles),
    ],
    axis=-1,
)
# Desired end-effector orientation: pointing down.
target_z = jnp.array([0.0, 0.0, -1.0])
# Reference configuration (middle of joint ranges)
q_ref = (joint_limits_lower + joint_limits_upper) / 2
print(f"Solving IK for {n_targets} targets along circular path")
print(f"Circle center: {center}, radius: {radius}")
Solving IK for 16 targets along circular path
Circle center: [0.4 0. 0.4], radius: 0.15
@jax.jit
def solve_ik(target_pos: jax.Array, q_init: jax.Array) -> jax.Array:
    """Solve IK for a single target position.

    Args:
        target_pos: Target end-effector position (3,)
        q_init: Initial joint angles guess (7,)

    Returns:
        Optimized joint angles (7,)
    """
    costs = [
        position_constraint(joint_var, target_pos),
        orientation_constraint(joint_var, target_z),
        regularization_cost(joint_var, q_ref),
        joint_upper_limit(joint_var, joint_limits_upper),
        joint_lower_limit(joint_var, joint_limits_lower),
    ]
    problem = jaxls.LeastSquaresProblem(costs, [joint_var]).analyze()
    solution = problem.solve(
        initial_vals=jaxls.VarValues.make([joint_var.with_value(q_init)]),
        verbose=False,
        termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    )
    return solution[joint_var]
# Solve IK sequentially, using previous solution as initial guess.
solutions = []
q_current = q_ref
for i in range(n_targets):
    q_current = solve_ik(targets[i], q_current)
    solutions.append(q_current)
solutions = jnp.stack(solutions)
print(f"Solved {n_targets} IK problems")
# Verify solutions.
position_errors = []
for i in range(n_targets):
    pos, _ = forward_kinematics(solutions[i])
    position_errors.append(float(jnp.linalg.norm(pos - targets[i])))
print(
    f"Position errors: mean={jnp.mean(jnp.array(position_errors)) * 1000:.2f}mm, max={max(position_errors) * 1000:.2f}mm"
)
INFO | Building optimization problem with 5 terms and 1 variables: 1 costs, 2 eq_zero, 2 leq_zero, 0 geq_zero
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_position_constraint
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_orientation_constraint
INFO | Vectorizing group with 1 costs, 1 variables each: regularization_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_leq_zero), 1 variables each: augmented_joint_upper_limit
INFO | Vectorizing constraint group with 1 constraints (constraint_leq_zero), 1 variables each: augmented_joint_lower_limit
Solved 16 IK problems
Position errors: mean=0.00mm, max=0.01mm
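The loop above warm-starts each solve from the previous solution. The effect of that pattern can be seen on a much smaller problem. The sketch below swaps jaxls for a plain Newton iteration on a hypothetical planar two-link arm and counts iterations with and without warm-starting. The unit link lengths, the circle parameters, and the helper names `fk`, `jacobian`, and `newton_ik` are all assumptions for illustration.

```python
import numpy as np


def fk(q: np.ndarray) -> np.ndarray:
    """End-effector position of a planar two-link arm with unit link lengths."""
    return np.array([
        np.cos(q[0]) + np.cos(q[0] + q[1]),
        np.sin(q[0]) + np.sin(q[0] + q[1]),
    ])


def jacobian(q: np.ndarray) -> np.ndarray:
    """Analytic 2x2 Jacobian of fk with respect to the joint angles."""
    s1, c1 = np.sin(q[0]), np.cos(q[0])
    s12, c12 = np.sin(q[0] + q[1]), np.cos(q[0] + q[1])
    return np.array([[-s1 - s12, -s12], [c1 + c12, c12]])


def newton_ik(target, q_init, tol=1e-10, max_iters=50):
    """Newton iteration on fk(q) = target; returns (q, iterations used)."""
    q = np.array(q_init, dtype=float)
    for it in range(max_iters):
        err = target - fk(q)
        if np.linalg.norm(err) < tol:
            return q, it
        q = q + np.linalg.solve(jacobian(q), err)
    return q, max_iters


# Targets on a small circle that stays inside the arm's workspace.
angles = np.linspace(0.0, 2.0 * np.pi, 16, endpoint=False)
targets = np.stack([1.0 + 0.2 * np.cos(angles), 0.5 + 0.2 * np.sin(angles)], axis=-1)

q_start = np.array([0.0, 1.5])
q_current = q_start
warm_iters, cold_iters, errors = [], [], []
for target in targets:
    _, n_cold = newton_ik(target, q_start)            # always restart from q_start
    q_current, n_warm = newton_ik(target, q_current)  # reuse the previous solution
    cold_iters.append(n_cold)
    warm_iters.append(n_warm)
    errors.append(float(np.linalg.norm(fk(q_current) - target)))

print("max position error:", max(errors))
print("total iterations (warm, cold):", sum(warm_iters), sum(cold_iters))
```

On a well-conditioned problem like this one, warm starts typically need fewer iterations because each initial guess is already close to the next solution. They also tend to keep consecutive solutions on the same kinematic branch, which helps produce smooth joint trajectories.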
Visualization
3D visualization of the robot arm at different poses along the target trajectory:
viser server listening on *:8082 (HTTP: http://localhost:8082, Websocket: ws://localhost:8082)
Joint trajectories
The joint angles over the circular path, showing joint limit constraints are satisfied:
The solver finds smooth joint trajectories that stay within the joint limits (dotted lines) while tracking the circular end-effector path with sub-millimeter accuracy (maximum position error of 0.01 mm). The 7-DOF arm is redundant for this task, and the regularization cost resolves that redundancy by keeping the joints near their reference values.
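The role of the regularization term can be illustrated on a tiny linear problem with the same structure: one task equation, two unknowns, and therefore infinitely many exact solutions. In this NumPy sketch (a toy problem invented for illustration, with the task treated as a soft least-squares term rather than a hard constraint), stacking a weighted `x - x_ref` residual with weight 0.1, as in `regularization_cost`, selects a single solution near the reference.

```python
import numpy as np

# Task: x[0] + x[1] = 1. Any point on this line solves it exactly.
A_task = np.array([[1.0, 1.0]])
b_task = np.array([1.0])

# Regularization residual w * (x - x_ref), mirroring the 0.1 weight used above.
w = 0.1
x_ref = np.array([0.0, 0.0])
A_reg = w * np.eye(2)
b_reg = w * x_ref

# Stack both residual blocks and solve the combined least-squares problem.
A = np.vstack([A_task, A_reg])
b = np.concatenate([b_task, b_reg])
x, *_ = np.linalg.lstsq(A, b, rcond=None)

print("solution:", x)
print("task residual:", A_task @ x - b_task)
```

Because the task is only a soft term in this toy, the regularizer pulls the solution slightly off the line. In the notebook the position and orientation targets are equality constraints, so the regularizer should only choose among configurations that satisfy them.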
For solver configuration options, see jaxls.TrustRegionConfig and jaxls.AugmentedLagrangianConfig.
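For intuition about what an augmented Lagrangian method does with an equality constraint, here is a textbook-style sketch on a one-dimensional toy problem: minimize (x - 2)^2 subject to x - 1 = 0. It is not jaxls's implementation, and the penalty value, iteration count, and variable names are arbitrary choices for the example.

```python
# Toy problem: minimize f(x) = (x - 2)^2 subject to c(x) = x - 1 = 0.
# Augmented Lagrangian: L(x) = f(x) + lam * c(x) + 0.5 * rho * c(x)^2.

rho = 10.0  # penalty weight (kept fixed here for simplicity)
lam = 0.0   # Lagrange multiplier estimate
x = 0.0

for outer in range(30):
    # Inner step: L is quadratic in x, so its minimizer has a closed form.
    # dL/dx = 2 * (x - 2) + lam + rho * (x - 1) = 0
    x = (4.0 - lam + rho) / (2.0 + rho)
    # Outer step: first-order multiplier update using the constraint violation.
    lam = lam + rho * (x - 1.0)

print("x:", x)
print("multiplier:", lam)
```

The iterates approach the constrained minimizer x = 1 without driving the penalty weight to infinity, because the multiplier estimate absorbs the constraint force. At the optimum, stationarity gives 2 * (1 - 2) + lam = 0, so the multiplier converges to 2.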