Cam profile optimization
In this notebook, we solve a cam design problem: finding a cam profile that produces smooth follower motion while satisfying pressure-angle and minimum-radius constraints.
Inputs: Target dwell-rise-dwell-return displacement profile
Outputs: B-spline control points defining the cam shape
Features used:
Var for B-spline control point radii
@jaxls.Cost.factory for motion tracking and smoothness
Inequality constraints for pressure angle and minimum radius limits
import jax
import jax.numpy as jnp
import jaxls
import numpy as np
Cam geometry
We model a radial plate cam with a translating knife-edge follower. The cam profile is defined by radii \(r(\theta)\) at discrete angles. The follower displacement equals \(r(\theta) - r_{base}\) where \(r_{base}\) is the base circle radius.
The pressure angle \(\alpha\) is the angle between the follower motion direction and the normal to the cam surface. Large pressure angles cause high side loads on the follower guide. For a profile given by \(r(\theta)\), it is:
\[\alpha = \arctan\left(\frac{dr/d\theta}{r}\right)\]
We use a B-spline parameterization for smooth, manufacturable cam profiles.
def cubic_bspline_basis(t: jax.Array, i: int, n_ctrl: int) -> jax.Array:
    """Evaluate cubic B-spline basis function i at parameter t.

    Uses uniform periodic knots for a closed curve.
    """
    # Map t to [0, n] with wrapping
    u = t * n_ctrl
    # Local parameter within the basis function support
    local_u = (u - i) % n_ctrl
    # Cubic B-spline basis (Cox-de Boor for uniform knots)
    # Only non-zero in [0, 4)
    basis = jnp.where(
        local_u < 1,
        local_u**3 / 6,
        jnp.where(
            local_u < 2,
            (-3 * (local_u - 1) ** 3 + 3 * (local_u - 1) ** 2 + 3 * (local_u - 1) + 1)
            / 6,
            jnp.where(
                local_u < 3,
                (3 * (local_u - 2) ** 3 - 6 * (local_u - 2) ** 2 + 4) / 6,
                jnp.where(local_u < 4, (1 - (local_u - 3)) ** 3 / 6, 0.0),
            ),
        ),
    )
    return basis
@jax.jit
def eval_cam_radius(ctrl_points: jax.Array, theta: jax.Array) -> jax.Array:
    """Evaluate cam radius at angle theta using B-spline interpolation.

    Args:
        ctrl_points: B-spline control point radii (n_ctrl,)
        theta: Cam angle in [0, 2*pi]

    Returns:
        Cam radius at theta
    """
    n_ctrl = len(ctrl_points)
    t = theta / (2 * jnp.pi)  # Normalize to [0, 1]
    # Sum basis functions weighted by control points
    radius = sum(
        ctrl_points[i] * cubic_bspline_basis(t, i, n_ctrl) for i in range(n_ctrl)
    )
    return radius
def cam_profile_xy(ctrl_points: jax.Array, n_points: int = 360) -> jax.Array:
    """Generate cam profile as (x, y) points.

    Args:
        ctrl_points: B-spline control point radii
        n_points: Number of points around the profile

    Returns:
        (n_points, 2) array of (x, y) coordinates
    """
    thetas = jnp.linspace(0, 2 * jnp.pi, n_points, endpoint=False)
    radii = jax.vmap(lambda th: eval_cam_radius(ctrl_points, th))(thetas)
    x = radii * jnp.cos(thetas)
    y = radii * jnp.sin(thetas)
    return jnp.stack([x, y], axis=-1)
@jax.jit
def compute_pressure_angle(ctrl_points: jax.Array, theta: jax.Array) -> jax.Array:
    """Compute pressure angle at given cam angle.

    The pressure angle is arctan(dr/dtheta / r).

    Args:
        ctrl_points: B-spline control point radii
        theta: Cam angle

    Returns:
        Pressure angle in radians
    """
    r = eval_cam_radius(ctrl_points, theta)
    dr_dtheta = jax.grad(lambda th: eval_cam_radius(ctrl_points, th))(theta)
    return jnp.arctan2(jnp.abs(dr_dtheta), r)
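A quick way to sanity-check a periodic B-spline parameterization like the one above is the partition-of-unity property: the basis functions should sum to 1 at every angle, so identical control points should reproduce a perfect circle. The sketch below re-implements the same piecewise polynomials in a vectorized form; the names `uniform_cubic_basis` and `periodic_radius` are illustrative and are not part of the notebook or of jaxls.

```python
import jax
import jax.numpy as jnp


def uniform_cubic_basis(local_u: jax.Array) -> jax.Array:
    """Uniform cubic B-spline pieces on [0, 4); zero elsewhere (local_u >= 0)."""
    v1, v2, v3 = local_u - 1, local_u - 2, local_u - 3
    return jnp.select(
        [local_u < 1, local_u < 2, local_u < 3, local_u < 4],
        [
            local_u**3 / 6,
            (-3 * v1**3 + 3 * v1**2 + 3 * v1 + 1) / 6,
            (3 * v2**3 - 6 * v2**2 + 4) / 6,
            (1 - v3) ** 3 / 6,
        ],
        default=0.0,
    )


def periodic_radius(ctrl: jax.Array, theta: jax.Array) -> jax.Array:
    """Closed-curve radius: every control point contributes via a wrapped offset."""
    n = ctrl.shape[0]
    u = theta / (2 * jnp.pi) * n
    local_u = (u - jnp.arange(n)) % n  # one wrapped local parameter per control point
    return jnp.sum(ctrl * uniform_cubic_basis(local_u))


# Constant control points should give a constant radius (a circle).
thetas = jnp.linspace(0.0, 2 * jnp.pi, 50, endpoint=False)
const_ctrl = jnp.full(16, 1.25)
radii = jax.vmap(lambda th: periodic_radius(const_ctrl, th))(thetas)
max_dev = float(jnp.max(jnp.abs(radii - 1.25)))
print(f"Max deviation from a circle of radius 1.25: {max_dev:.2e}")
```

The property only holds when there are at least four control points, since each angle must be covered by four distinct basis functions.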
Target motion profile
We design a cam for a dwell-rise-dwell-return motion:
Dwell at low position (0-90°)
Rise to high position (90-180°)
Dwell at high position (180-270°)
Return to low position (270-360°)
We use a cycloidal motion law, whose velocity and acceleration both vanish at the start and end of each transition.
def cycloidal_rise(beta: jax.Array) -> jax.Array:
    """Cycloidal rise motion: smooth acceleration profile.

    Args:
        beta: Normalized angle [0, 1].

    Returns:
        Displacement [0, 1].
    """
    return beta - jnp.sin(2 * jnp.pi * beta) / (2 * jnp.pi)
def target_displacement(theta: jax.Array, lift: float = 0.3) -> jax.Array:
    """Target follower displacement for dwell-rise-dwell-return motion.

    Args:
        theta: Cam angle in radians.
        lift: Total follower lift.

    Returns:
        Target displacement from base position.
    """
    theta_deg = jnp.rad2deg(theta) % 360
    # Dwell at low (0-90).
    low_dwell = theta_deg < 90
    # Rise (90-180).
    rise = (theta_deg >= 90) & (theta_deg < 180)
    rise_beta = (theta_deg - 90) / 90
    # Dwell at high (180-270).
    high_dwell = (theta_deg >= 180) & (theta_deg < 270)
    # Return (270-360).
    ret_beta = (theta_deg - 270) / 90
    displacement = jnp.where(
        low_dwell,
        0.0,
        jnp.where(
            rise,
            lift * cycloidal_rise(rise_beta),
            jnp.where(high_dwell, lift, lift * (1 - cycloidal_rise(ret_beta))),
        ),
    )
    return displacement
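The reason a cycloidal law joins cleanly to the dwell segments can be checked numerically: its displacement, velocity, and acceleration should all match a dwell (0 or 1 displacement, zero derivatives) at both ends. The following is a small sketch using `jax.grad`; the helper names `velocity` and `acceleration` are illustrative additions, and the derivatives are taken with respect to the normalized angle, not the cam angle.

```python
import jax
import jax.numpy as jnp


def cycloidal_rise(beta):
    """Same formula as in the notebook: normalized displacement on [0, 1]."""
    return beta - jnp.sin(2 * jnp.pi * beta) / (2 * jnp.pi)


# First and second derivatives with respect to the normalized angle beta.
velocity = jax.grad(cycloidal_rise)
acceleration = jax.grad(velocity)

for beta in (0.0, 1.0):
    print(
        f"beta={beta}: s={float(cycloidal_rise(beta)):.4f}, "
        f"v={float(velocity(beta)):.1e}, a={float(acceleration(beta)):.1e}"
    )

# The velocity peaks at mid-stroke, where 1 - cos(pi) = 2.
peak_velocity = float(velocity(0.5))
print(f"Peak normalized velocity: {peak_velocity:.3f}")
```

In single precision the end-point derivatives come out as small round-off values rather than exact zeros.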
# Base circle radius.
R_BASE = 1.0
LIFT = 0.3
# Generate target profile.
n_sample = 72
sample_thetas = jnp.linspace(0, 2 * jnp.pi, n_sample, endpoint=False)
target_radii = R_BASE + jax.vmap(lambda th: target_displacement(th, LIFT))(
sample_thetas
)
print("Target motion: dwell-rise-dwell-return")
print(f"Base radius: {R_BASE}, Lift: {LIFT}")
print(f"Sampling {n_sample} points around the cam")
Target motion: dwell-rise-dwell-return
Base radius: 1.0, Lift: 0.3
Sampling 72 points around the cam
Optimization variables and costs
We optimize B-spline control points to match the target motion while satisfying constraints.
# Number of B-spline control points.
N_CTRL = 16
class CamControlPointsVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.ones(N_CTRL) * R_BASE,
):
    """B-spline control point radii for cam profile."""
cam_var = CamControlPointsVar(id=0)
@jaxls.Cost.factory
def motion_tracking_cost(
    vals: jaxls.VarValues,
    var: CamControlPointsVar,
    theta: jax.Array,
    target_r: jax.Array,
) -> jax.Array:
    """Cost for cam radius to match target displacement."""
    ctrl_pts = vals[var]
    actual_r = eval_cam_radius(ctrl_pts, theta)
    # Weight the error by 10 to prioritize motion tracking.
    return (actual_r - target_r) * 10.0
@jaxls.Cost.factory(kind="constraint_leq_zero")
def min_radius_constraint(
    vals: jaxls.VarValues,
    var: CamControlPointsVar,
    min_radius: float,
) -> jax.Array:
    """All control points must be >= minimum radius."""
    ctrl_pts = vals[var]
    return min_radius - ctrl_pts
@jaxls.Cost.factory(kind="constraint_leq_zero")
def pressure_angle_constraint(
    vals: jaxls.VarValues,
    var: CamControlPointsVar,
    theta: jax.Array,
    max_angle: float,
) -> jax.Array:
    """Pressure angle must be <= maximum allowed."""
    ctrl_pts = vals[var]
    alpha = compute_pressure_angle(ctrl_pts, theta)
    return jnp.array([alpha - max_angle])
@jaxls.Cost.factory
def smoothness_cost(
    vals: jaxls.VarValues,
    var: CamControlPointsVar,
    index: int,
    n_ctrl: int,
) -> jax.Array:
    """Penalize second differences for smooth profile."""
    ctrl_pts = vals[var]
    # Second difference (discrete curvature)
    prev_idx = (index - 1) % n_ctrl
    next_idx = (index + 1) % n_ctrl
    curvature = ctrl_pts[prev_idx] - 2 * ctrl_pts[index] + ctrl_pts[next_idx]
    return jnp.array([curvature * 0.5])
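The smoothness term above computes one periodic second difference per control point. As an illustration of what those residuals look like for the whole profile at once, the same quantity can be written with `jnp.roll`; this is a sketch in plain JAX, and `periodic_second_differences` is a hypothetical helper, not something the notebook or jaxls defines.

```python
import jax.numpy as jnp


def periodic_second_differences(ctrl_pts):
    """ctrl[i-1] - 2*ctrl[i] + ctrl[i+1], with indices wrapping around the cam."""
    prev_pts = jnp.roll(ctrl_pts, 1)   # element i holds ctrl_pts[i - 1]
    next_pts = jnp.roll(ctrl_pts, -1)  # element i holds ctrl_pts[i + 1]
    return prev_pts - 2 * ctrl_pts + next_pts


# A step between two radii produces non-zero curvature only near the jumps.
step_ctrl = jnp.array([1.0, 1.0, 1.3, 1.3])
step_curvature = periodic_second_differences(step_ctrl)

# A perfect circle (constant radii) has zero discrete curvature everywhere.
circle_curvature = periodic_second_differences(jnp.full(16, 1.0))

print(step_curvature)
print(float(jnp.max(jnp.abs(circle_curvature))))
```

Because the stencil wraps around, the last and first control points are treated as neighbours, which is what a closed cam profile needs.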
Solve the cam design problem
# Maximum allowed pressure angle (30 degrees is typical limit).
MAX_PRESSURE_ANGLE = jnp.deg2rad(30)
MIN_RADIUS = 0.7
# Build costs.
# Note: loops used for clarity; batched construction is more efficient.
costs: list[jaxls.Cost] = []
# Motion tracking at sample points.
for i in range(n_sample):
    costs.append(motion_tracking_cost(cam_var, sample_thetas[i], target_radii[i]))
# Minimum radius constraint.
costs.append(min_radius_constraint(cam_var, MIN_RADIUS))
# Pressure angle constraints at key points (during rise and return phases).
pressure_check_angles = jnp.linspace(
jnp.deg2rad(95), jnp.deg2rad(175), 8
) # Rise phase.
pressure_check_angles = jnp.concatenate(
[pressure_check_angles, jnp.linspace(jnp.deg2rad(275), jnp.deg2rad(355), 8)]
) # Return phase.
for theta in pressure_check_angles:
    costs.append(pressure_angle_constraint(cam_var, theta, MAX_PRESSURE_ANGLE))
# Smoothness regularization.
for i in range(N_CTRL):
    costs.append(smoothness_cost(cam_var, i, N_CTRL))
# Initial guess: sample target radii at control point angles.
ctrl_indices = jnp.linspace(0, n_sample, N_CTRL, endpoint=False).astype(int)
initial_ctrl = target_radii[ctrl_indices]
# Create the problem.
problem = jaxls.LeastSquaresProblem(costs, [cam_var])
# Visualize the problem structure.
problem.show()
# Analyze and solve.
problem = problem.analyze()
solution = problem.solve(
initial_vals=jaxls.VarValues.make([cam_var.with_value(initial_ctrl)]),
termination=jaxls.TerminationConfig(cost_tolerance=1e-6, max_iterations=100),
)
opt_ctrl = solution[cam_var]
print(f"Optimized {N_CTRL} control points")
print(f"Radius range: [{float(jnp.min(opt_ctrl)):.3f}, {float(jnp.max(opt_ctrl)):.3f}]")
INFO | Building optimization problem with 105 terms and 1 variables: 88 costs, 0 eq_zero, 17 leq_zero, 0 geq_zero
INFO | Vectorizing group with 72 costs, 1 variables each: motion_tracking_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_leq_zero), 1 variables each: augmented_min_radius_constraint
INFO | Vectorizing group with 16 costs, 1 variables each: smoothness_cost
INFO | Vectorizing constraint group with 16 constraints (constraint_leq_zero), 1 variables each: augmented_pressure_angle_constraint
INFO | Augmented Lagrangian: initial snorm=0.0000e+00, csupn=0.0000e+00, max_rho=9.2798e+02, constraint_dim=32
INFO | step #0: cost=92.7978 lambd=0.0005 inexact_tol=1.0e-02
INFO | - motion_tracking_cost(72): 92.78764 (avg 1.28872)
INFO | - augmented_min_radius_constraint(1): 0.00000 (avg 0.00000)
INFO | - smoothness_cost(16): 0.01011 (avg 0.00063)
INFO | - augmented_pressure_angle_constraint(16): 0.00000 (avg 0.00000)
INFO | accepted=True ATb_norm=2.03e+02 cost_prev=92.7977 cost_new=0.0211
INFO | step #1: cost=0.0211 lambd=0.0003 inexact_tol=1.0e-02
INFO | - motion_tracking_cost(72): 0.00084 (avg 0.00001)
INFO | - augmented_min_radius_constraint(1): 0.00000 (avg 0.00000)
INFO | - smoothness_cost(16): 0.02025 (avg 0.00127)
INFO | - augmented_pressure_angle_constraint(16): 0.00000 (avg 0.00000)
INFO | accepted=False ATb_norm=2.44e-04 cost_prev=0.0211 cost_new=0.0211
INFO | AL update: snorm=0.0000e+00, csupn=0.0000e+00, max_rho=9.2798e+02
INFO | Terminated @ iteration #2: cost=0.0211 criteria=[1 0 1], term_deltas=4.4e-07,1.4e-04,1.7e-07
Optimized 16 control points
Radius range: [0.999, 1.301]
# Evaluate optimized cam.
eval_thetas = jnp.linspace(0, 2 * jnp.pi, 360, endpoint=False)
achieved_radii = jax.vmap(lambda th: eval_cam_radius(opt_ctrl, th))(eval_thetas)
target_radii_fine = R_BASE + jax.vmap(lambda th: target_displacement(th, LIFT))(
eval_thetas
)
# Compute pressure angles.
pressure_angles = jax.vmap(lambda th: compute_pressure_angle(opt_ctrl, th))(eval_thetas)
# Compute errors.
tracking_error = achieved_radii - target_radii_fine
print("\nMotion tracking:")
print(f" RMS error: {float(jnp.sqrt(jnp.mean(tracking_error**2))):.4f}")
print(f" Max error: {float(jnp.max(jnp.abs(tracking_error))):.4f}")
print("\nPressure angle:")
print(f" Max: {float(jnp.rad2deg(jnp.max(pressure_angles))):.1f}°")
print(f" Limit: {float(jnp.rad2deg(MAX_PRESSURE_ANGLE)):.1f}°")
Motion tracking:
RMS error: 0.0003
Max error: 0.0009
Pressure angle:
Max: 18.1°
Limit: 30.0°
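The solver log reports zero for the pressure-angle constraint terms, which suggests the constraints never became active. One way to see why is to compute the pressure angle of the ideal cycloidal rise directly, without any B-spline fit. The sketch below does this in plain JAX using the notebook's base radius and lift; `rise_radius` and `pressure_angle` are illustrative names, and only the rise segment (90° to 180°) is modelled.

```python
import jax
import jax.numpy as jnp

R_BASE, LIFT = 1.0, 0.3  # values used in the notebook
RISE_START, RISE_SPAN = jnp.pi / 2, jnp.pi / 2  # rise occupies 90 to 180 degrees


def rise_radius(theta):
    """Ideal target radius on the rise segment (cycloidal law)."""
    beta = (theta - RISE_START) / RISE_SPAN
    return R_BASE + LIFT * (beta - jnp.sin(2 * jnp.pi * beta) / (2 * jnp.pi))


def pressure_angle(theta):
    """arctan(|dr/dtheta| / r), with the slope obtained by autodiff."""
    r = rise_radius(theta)
    dr_dtheta = jax.grad(rise_radius)(theta)
    return jnp.arctan2(jnp.abs(dr_dtheta), r)


thetas = jnp.linspace(RISE_START, RISE_START + RISE_SPAN, 181)
alphas_deg = jnp.rad2deg(jax.vmap(pressure_angle)(thetas))
max_alpha_deg = float(jnp.max(alphas_deg))
print(f"Peak pressure angle of the ideal rise: {max_alpha_deg:.1f} deg")
```

Under these assumptions the peak falls between 18° and 19°, close to the 18.1° reported for the optimized profile and well under the 30° limit, so the target motion is feasible before any constraint has to push on it.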
Visualization
View the optimized cam profile and its motion characteristics:
Pressure angle analysis
The pressure angle affects force transmission efficiency. High angles cause side loads on the follower.
The optimized cam profile closely tracks the desired dwell-rise-dwell-return motion while keeping the pressure angle within acceptable limits. The B-spline parameterization ensures a smooth, manufacturable profile.
Key design tradeoffs:
Larger base circle → lower pressure angles but larger cam size
Slower rise/return → lower pressure angles but reduced machine speed
More control points → better motion tracking but more complex profile
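The first two tradeoffs can be illustrated with a small parameter sweep on the ideal cycloidal rise. This is a sketch in plain JAX rather than a re-run of the optimization: `max_pressure_angle_deg` is a hypothetical helper, the lift matches the notebook, and the base radii and rise spans are arbitrary values chosen for illustration. The third tradeoff (number of control points) is not covered here.

```python
import jax
import jax.numpy as jnp

LIFT = 0.3  # follower lift, as in the notebook


def max_pressure_angle_deg(r_base, rise_span, n=181):
    """Peak pressure angle (degrees) of an ideal cycloidal rise over rise_span radians."""

    def radius(theta):
        beta = theta / rise_span
        return r_base + LIFT * (beta - jnp.sin(2 * jnp.pi * beta) / (2 * jnp.pi))

    thetas = jnp.linspace(0.0, rise_span, n)
    r = jax.vmap(radius)(thetas)
    dr_dtheta = jax.vmap(jax.grad(radius))(thetas)
    return float(jnp.rad2deg(jnp.max(jnp.arctan2(jnp.abs(dr_dtheta), r))))


# Sweep the base circle radius at a fixed 90-degree rise.
by_base = {rb: max_pressure_angle_deg(rb, jnp.pi / 2) for rb in (0.5, 1.0, 2.0)}
# Sweep the rise duration (in degrees of cam rotation) at a fixed base radius.
by_span = {deg: max_pressure_angle_deg(1.0, jnp.deg2rad(deg)) for deg in (60.0, 90.0, 120.0)}

for rb, alpha in by_base.items():
    print(f"base radius {rb}: peak pressure angle {alpha:.1f} deg")
for deg, alpha in by_span.items():
    print(f"rise over {deg:.0f} deg: peak pressure angle {alpha:.1f} deg")
```

Both sweeps should show the peak pressure angle falling as the base circle grows or as the rise is spread over more cam rotation.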