Cloth draping
Note
This is a feature demonstration, not a practical simulator.
In this notebook, we simulate a cloth falling and draping over a sphere using Position-Based Dynamics (PBD), a popular approach for real-time cloth simulation in games and graphics.
PBD works by treating positions as the primary state and solving a constrained optimization problem at each timestep:
1. Predict positions using current velocities and external forces (gravity).
2. Solve for positions that satisfy the constraints (springs, collisions) while staying close to the predictions.
3. Update velocities from the position change: $v = (x_\text{new} - x_\text{old}) / \Delta t$.
This approach is unconditionally stable and naturally handles position-based constraints like collisions and stretch limits.
Features used:
- `Var` subclassing for custom 3D point variables
- `@jaxls.Cost.factory` for potential energy and inertia terms
- Inequality constraints (`constraint_geq_zero`) for table and sphere collision
- Time-stepping simulation with implicit integration
- Viser's `StateSerializer` for animated visualization
import sys
from loguru import logger
# Swap loguru's default stderr handler for a compact "LEVEL | message" sink on stdout.
logger.remove()
logger.add(
    sys.stdout,
    format="<level>{level: <8}</level> | {message}",
);
import jax
import jax.numpy as jnp
import jaxls
Variables and costs
We define costs for springs and an inertia term for time integration, plus inequality constraints for collisions with the table and sphere. The inertia cost penalizes deviation from the predicted position, which is computed from the current velocity and gravity.
class Point3Var(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))):
    """Optimization variable holding a single 3D position; defaults to the origin."""
@jaxls.Cost.factory
def spring_cost(
    vals: jaxls.VarValues,
    var_a: Point3Var,
    var_b: Point3Var,
    rest_length: jax.Array,
    stiffness: float,
) -> jax.Array:
    """Residual for a Hookean spring joining two points.

    The residual is ``sqrt(stiffness) * (length - rest_length)`` directed along
    the (approximately unit) vector from ``var_b`` to ``var_a``. Its squared norm
    is therefore ~``stiffness * (length - rest_length)**2``, which is
    proportional to the spring's elastic potential energy.
    """
    delta = vals[var_a] - vals[var_b]
    # The epsilon keeps the sqrt (and its derivative) finite if the points coincide.
    current_length = jnp.sqrt(jnp.sum(delta**2) + 1e-6)
    unit_dir = delta / current_length
    stretch_scale = (current_length - rest_length) * jnp.sqrt(stiffness)
    return stretch_scale * unit_dir
@jaxls.Cost.factory
def inertia_cost(
    vals: jaxls.VarValues,
    var: Point3Var,
    predicted_pos: jax.Array,
    mass: float,
    dt: float,
) -> jax.Array:
    """Inertia term for implicit time integration.

    Pulls the point toward its inertial prediction (current velocity plus
    gravity). The residual is scaled by ``sqrt(mass) / dt``, so its squared norm
    equals ``mass * ||x - x_pred||**2 / dt**2`` and scales correctly with both
    mass and timestep.
    """
    scale = jnp.sqrt(mass) / dt
    offset = vals[var] - predicted_pos
    return scale * offset
@jaxls.Cost.factory(kind="constraint_geq_zero")
def table_constraint(
    vals: jaxls.VarValues,
    var: Point3Var,
) -> jax.Array:
    """Non-penetration constraint for the table plane.

    Returns the point's z coordinate. The ``constraint_geq_zero`` kind asks the
    solver to keep it at or above zero.
    """
    position = vals[var]
    return position[2]
@jaxls.Cost.factory(kind="constraint_geq_zero")
def sphere_constraint(
    vals: jaxls.VarValues,
    var: Point3Var,
    center: jax.Array,
    radius: float,
) -> jax.Array:
    """Non-penetration constraint for the sphere.

    Returns the signed clearance ``|p - center| - radius``. The
    ``constraint_geq_zero`` kind asks the solver to keep the point on or outside
    the sphere surface.
    """
    offset = vals[var] - center
    # The tiny epsilon keeps the sqrt differentiable when the point sits at the center.
    clearance = jnp.sqrt(jnp.sum(offset**2) + 1e-8) - radius
    return clearance
Grid setup
Create a grid of particles starting above the sphere, offset slightly from its center. We use a small grid to keep the simulation fast.
# Simulation lattice: kept small (15 x 15 particles) so each timestep is cheap.
cols = 15
rows = 15
num_points = cols * rows
spacing = 0.10

# The sphere rests on the table (its lowest point touches z = 0) and sits
# beneath the middle of the un-offset grid footprint.
sphere_radius = 0.35
sphere_center = jnp.array(
    [
        (cols - 1) * spacing / 2,
        (rows - 1) * spacing / 2,
        sphere_radius,
    ]
)

# Release height: 1.5 units above the top of the sphere.
initial_height = sphere_center[2] + sphere_radius + 1.5

# Shift the cloth off-center relative to the sphere so it does not land
# symmetrically and instead slips off one side.
cloth_offset = jnp.array([0.15, 0.05, 0.0])
def idx(row: int, col: int) -> int:
    """Return the row-major flat index of grid cell ``(row, col)``."""
    row_start = row * cols
    return row_start + col
# Lay the particles out row by row at a uniform release height, then shift the
# whole sheet by ``cloth_offset`` so it starts off-center above the sphere.
grid_xyz = [
    [c * spacing, r * spacing, initial_height]
    for r in range(rows)
    for c in range(cols)
]
initial_positions = jnp.array(grid_xyz) + cloth_offset

# The cloth starts at rest.
velocities = jnp.zeros_like(initial_positions)

print(f"Grid: {cols}x{rows} = {num_points} points")
print(f"Cloth size: {(cols - 1) * spacing:.2f} x {(rows - 1) * spacing:.2f} units")
print(f"Sphere: center={sphere_center}, radius={sphere_radius}")
print(f"Initial cloth height: {initial_height:.2f} units")
print(f"Cloth offset: [{cloth_offset[0]:.2f}, {cloth_offset[1]:.2f}] units")
Grid: 15x15 = 225 points
Cloth size: 1.40 x 1.40 units
Sphere: center=[0.7 0.7 0.35], radius=0.35
Initial cloth height: 2.20 units
Cloth offset: [0.15, 0.05] units
# No particle is pinned, so the set of free particles is every index in the grid.
free_indices = jnp.arange(num_points)
# One batched variable object covering all particles.
all_point_vars = Point3Var(id=free_indices)
def _offset_pairs(d_row: int, d_col: int) -> list[tuple[int, int]]:
    """Flat-index pairs linking (r, c) to (r + d_row, c + d_col), in row-major order."""
    return [
        (idx(r, c), idx(r + d_row, c + d_col))
        for r in range(rows - d_row)
        for c in range(cols - d_col)
    ]


def _split_pairs(pairs: list[tuple[int, int]]) -> tuple[jax.Array, jax.Array]:
    """Split a list of (a, b) index pairs into two parallel index arrays."""
    firsts, seconds = zip(*pairs)
    return jnp.array(firsts), jnp.array(seconds)


# Structural springs: horizontal neighbors first, then vertical neighbors.
struct_a, struct_b = _split_pairs(_offset_pairs(0, 1) + _offset_pairs(1, 0))
struct_rest_length = spacing

# Shear springs: both diagonals of every grid cell, interleaved per cell.
shear_pairs = []
for r in range(rows - 1):
    for c in range(cols - 1):
        shear_pairs.append((idx(r, c), idx(r + 1, c + 1)))
        shear_pairs.append((idx(r, c + 1), idx(r + 1, c)))
shear_a, shear_b = _split_pairs(shear_pairs)
shear_rest_length = spacing * jnp.sqrt(2)

# Bend springs: skip-one neighbors, horizontal then vertical.
bend_a, bend_b = _split_pairs(_offset_pairs(0, 2) + _offset_pairs(2, 0))
bend_rest_length = spacing * 2

# Two triangles per cell of the simulation grid. This coarse mesh is only
# reported here; the renderer below builds its own finer triangulation.
triangle_indices = jnp.array(
    [
        tri
        for r in range(rows - 1)
        for c in range(cols - 1)
        for tri in (
            [idx(r, c), idx(r + 1, c), idx(r, c + 1)],
            [idx(r + 1, c), idx(r + 1, c + 1), idx(r, c + 1)],
        )
    ]
)

print(f"Total springs: {len(struct_a) + len(shear_a) + len(bend_a)}")
print(f"Total triangles: {len(triangle_indices)}")
Total springs: 1202
Total triangles: 392
Simulation parameters
# Time stepping.
dt = 1.0 / 30.0  # 30 Hz simulation.
total_time = 5.0  # seconds
# round() rather than int(): total_time / dt is a float quotient that can land
# just below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and int() would
# truncate it and silently drop a step.
num_steps = round(total_time / dt)

# Physics parameters.
mass_per_point = 0.01  # kg
g = 9.81  # m/s²

# Damping factors (velocity is multiplied by one of these every step).
fabric_damping = 0.95  # Internal cloth damping to reduce oscillations.
friction_damping = 0.7  # Stronger damping for vertices touching surfaces.

# Spring stiffness (lower values give a softer cloth).
structural_stiffness = 50.0
shear_stiffness = 25.0
bend_stiffness = 10.0

print(f"Simulating {num_steps} steps at {1 / dt:.0f} Hz for {total_time:.1f}s")
Simulating 150 steps at 30 Hz for 5.0s
Build problem factory
We create a function that builds the optimization problem for each timestep. The problem structure stays the same; only the predicted positions change.
def build_problem(
    predicted_positions: jax.Array,
) -> jaxls.LeastSquaresProblem:
    """Assemble the least-squares problem for a single timestep.

    Only ``predicted_positions`` changes between timesteps. The spring
    topology, collision geometry and physical parameters are read from
    notebook-level globals.

    Args:
        predicted_positions: Inertial prediction for every particle, one row
            per point, in the same layout as ``initial_positions``.

    Returns:
        A ``jaxls.LeastSquaresProblem`` over ``all_point_vars``. It is not yet
        analyzed; the caller runs ``.analyze()``.
    """
    particles = Point3Var(id=free_indices)
    # Repeat the single sphere center so there is one row per particle.
    per_point_centers = jnp.tile(sphere_center[None, :], (num_points, 1))

    # One (endpoints a, endpoints b, rest length, stiffness) entry per spring family.
    spring_families = (
        (struct_a, struct_b, struct_rest_length, structural_stiffness),
        (shear_a, shear_b, shear_rest_length, shear_stiffness),
        (bend_a, bend_b, bend_rest_length, bend_stiffness),
    )

    costs: list[jaxls.Cost] = [
        # Inertia: stay close to the predicted trajectory.
        inertia_cost(particles, predicted_positions, mass_per_point, dt),
        # Collision constraints.
        table_constraint(particles),
        sphere_constraint(particles, per_point_centers, sphere_radius),
    ]
    costs += [
        spring_cost(Point3Var(id=ends_a), Point3Var(id=ends_b), rest, stiffness)
        for ends_a, ends_b, rest, stiffness in spring_families
    ]
    return jaxls.LeastSquaresProblem(costs, [all_point_vars])
Run simulation
At each timestep:
1. Compute predicted positions from the current velocity and gravity.
2. Solve the optimization to find positions that balance inertia, springs, and constraints.
3. Update velocities from the position change.
4. Apply damping.
# Uniform gravitational acceleration pointing down the -z axis.
gravity_vec = jnp.array([0.0, 0.0, -g])

# Augmented Lagrangian settings shared by every per-step solve. They are meant
# as tighter tolerances for better constraint satisfaction
# (NOTE(review): compare against the jaxls defaults to confirm they are tighter).
al_config = jaxls.AugmentedLagrangianConfig(
    tolerance_absolute=1e-6,
    tolerance_relative=1e-5,
    inner_solve_tolerance=1e-3,
)
def predict_positions(
    positions: jax.Array,
    velocities: jax.Array,
    dt: float,
    gravity: jax.Array,
) -> jax.Array:
    """Inertial (unconstrained) position prediction for one timestep.

    Uses the semi-implicit Euler step from Position-Based Dynamics: first
    ``v += dt * g``, then ``x_pred = x + dt * v``, which expands to
    ``x + dt * v + dt**2 * g``. The step later recovers velocities as
    ``(x_new - x) / dt``, so an unconstrained particle gains exactly
    ``g * dt`` of velocity per step (before damping).

    Args:
        positions: Current particle positions.
        velocities: Current particle velocities, same shape as ``positions``.
        dt: Timestep length.
        gravity: Gravitational acceleration, broadcast against ``positions``.

    Returns:
        Predicted positions with the same shape as ``positions``.
    """
    return positions + velocities * dt + gravity * dt**2


def simulation_step(
    state: tuple[jax.Array, jax.Array],
    _: None,
) -> tuple[tuple[jax.Array, jax.Array], jax.Array]:
    """Advance the cloth by one timestep; used as the ``jax.lax.scan`` body.

    Args:
        state: ``(positions, velocities)`` for every particle.
        _: Unused per-step scan input (the scan is driven with ``xs=None``).

    Returns:
        ``((new_positions, new_velocities), new_positions)``: the new carry,
        plus the positions recorded as this step's trajectory frame.
    """
    positions, velocities = state

    # NOTE: this previously used ``0.5 * gravity_vec * dt**2`` (a Verlet-style
    # position update). Combined with the finite-difference velocity update
    # below, that made free-falling particles accelerate at only g / 2.
    predicted = predict_positions(positions, velocities, dt, gravity_vec)

    # Build and analyze the problem for this timestep.
    problem = build_problem(predicted).analyze()

    # Start the solve from the current positions.
    initial_vals = jaxls.VarValues.make([all_point_vars.with_value(positions)])
    new_positions = problem.solve(
        initial_vals,
        linear_solver="cholmod",
        verbose=False,
        augmented_lagrangian=al_config,
    )[all_point_vars]

    # Velocities implied by the constrained position change.
    new_velocities = (new_positions - positions) / dt

    # Contact-aware damping as a cheap stand-in for friction: particles within
    # 0.01 of the table or the sphere surface are damped harder than free cloth.
    near_table = new_positions[:, 2] < 0.01
    dist_to_sphere = jnp.sqrt(jnp.sum((new_positions - sphere_center) ** 2, axis=1))
    near_sphere = dist_to_sphere < (sphere_radius + 0.01)
    in_contact = near_table | near_sphere
    new_velocities = jnp.where(
        in_contact[:, None],
        new_velocities * friction_damping,
        new_velocities * fabric_damping,
    )

    return (new_positions, new_velocities), new_positions
# Quieten solver logging: only WARNING and above reach stdout during the scan.
logger.remove()
logger.add(
    sys.stdout,
    format="<level>{level: <8}</level> | {message}",
    level="WARNING",
)

print("Running simulation with jax.lax.scan...")

# Roll the simulation forward from rest. The scan stacks each step's positions
# along a new leading axis of length num_steps.
initial_state = (initial_positions, jnp.zeros_like(initial_positions))
final_state, trajectory_array = jax.lax.scan(
    simulation_step, initial_state, None, length=num_steps
)
final_positions, final_velocities = final_state

# Frame 0 is the starting configuration, followed by one frame per step.
trajectory = [initial_positions]
trajectory.extend(trajectory_array[i] for i in range(num_steps))

lowest_z = float(final_positions[:, 2].min())
highest_z = float(final_positions[:, 2].max())
print(f" Final height range: [{lowest_z:.3f}, {highest_z:.3f}]")
print(f"Simulation complete! {len(trajectory)} frames recorded.")
Running simulation with jax.lax.scan...
Final height range: [-0.000, 0.176]
Simulation complete! 151 frames recorded.
Animated visualization
Use Viser's `StateSerializer` to create an animated playback of the simulation.
The visualization mesh is upsampled from the simulation grid using bicubic spline
interpolation for smoother rendering.
import contextlib
import io
import numpy as np
import scipy.ndimage
import trimesh
import viser
# Render on a 45 x 45 vertex grid, finer than the simulation lattice.
vis_rows = 45
vis_cols = 45

# Express every fine vertex as a fractional (row, col) coordinate in the coarse
# lattice's index space, so scipy can interpolate directly on the coarse grid.
fine_row_coords = np.linspace(0, rows - 1, vis_rows)
fine_col_coords = np.linspace(0, cols - 1, vis_cols)
fine_row_grid, fine_col_grid = np.meshgrid(
    fine_row_coords, fine_col_coords, indexing="ij"
)
# Shape (2, vis_rows, vis_cols): the coordinate layout map_coordinates expects.
interp_coords = np.stack((fine_row_grid, fine_col_grid))
def upsample_positions(coarse_positions: np.ndarray) -> np.ndarray:
    """Resample simulation positions onto the fine visualization grid.

    Each x/y/z channel of the ``rows x cols`` coarse lattice is sampled at
    ``interp_coords`` with cubic spline interpolation (``order=3``).

    Args:
        coarse_positions: ``rows * cols`` points in row-major grid order; any
            array-like that reshapes to ``(rows, cols, 3)``.

    Returns:
        Array of shape ``(vis_rows * vis_cols, 3)`` in row-major order.
    """
    lattice = np.asarray(coarse_positions).reshape(rows, cols, 3)
    channels = [
        scipy.ndimage.map_coordinates(lattice[:, :, axis], interp_coords, order=3)
        for axis in range(3)
    ]
    return np.stack(channels, axis=-1).reshape(-1, 3)
def vis_idx(row: int, col: int) -> int:
    """Row-major flat index of vertex ``(row, col)`` in the fine visualization grid."""
    return vis_cols * row + col


# Triangulate the fine grid: two triangles per cell, same winding as the
# simulation mesh.
vis_triangles = np.array(
    [
        tri
        for r in range(vis_rows - 1)
        for c in range(vis_cols - 1)
        for tri in (
            [vis_idx(r, c), vis_idx(r + 1, c), vis_idx(r, c + 1)],
            [vis_idx(r + 1, c), vis_idx(r + 1, c + 1), vis_idx(r, c + 1)],
        )
    ]
)
# Shift the scene so the sphere's x/y position (the center of the un-offset
# grid) sits at the world origin. z is left unchanged, so the table stays at z = 0.
cloth_center = np.array([(cols - 1) * spacing / 2, (rows - 1) * spacing / 2, 0.0])
centered_sphere_center = np.array(sphere_center) - cloth_center

# The rendered sphere is 2e-2 smaller than the collision radius. This is
# presumably so cloth resting on the collision surface draws cleanly on top of
# it (NOTE(review): confirm intent).
sphere_trimesh = trimesh.creation.icosphere(
    subdivisions=3,
    radius=float(sphere_radius) - 2e-2,
)
sphere_vertices = np.array(sphere_trimesh.vertices) + centered_sphere_center
sphere_faces = np.array(sphere_trimesh.faces)
# Start the viser server, sending anything it prints during construction into
# throwaway buffers.
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(io.StringIO()):
    server = viser.ViserServer(verbose=False)

# Initial camera: looking from the front-right, slightly above the sphere.
camera = server.initial_camera
camera.position = (1.5, -1.5, 1.0)
camera.look_at = (0.0, 0.0, 0.3)

# Static scene: the (visually shrunk) sphere plus an infinite table grid.
server.scene.add_mesh_simple(
    "/sphere",
    vertices=sphere_vertices,
    faces=sphere_faces,
    color=(200, 80, 80),
)
table_grid_style = dict(
    width=10.0,
    height=10.0,
    plane="xy",
    position=(0, 0, -0.01),
    infinite_grid=True,
    fade_distance=50.0,
    cell_color=(200, 200, 200),
    section_color=(170, 170, 170),
)
server.scene.add_grid("/table", **table_grid_style)
server.scene.set_up_direction("+z")
# Upload every simulated frame up front as a hidden, upsampled mesh. Playback
# then only has to flip visibility flags.
cloth_handles = [
    server.scene.add_mesh_simple(
        f"/cloth_{frame_idx}",
        vertices=upsample_positions(frame_positions) - cloth_center,
        faces=vis_triangles,
        flat_shading=False,
        side="double",
        color=(100, 150, 220),
        visible=False,
    )
    for frame_idx, frame_positions in enumerate(trajectory)
]

# Record the playback: show each frame in turn (hiding the one before it) and
# hold it for one simulation timestep.
serializer = server.get_scene_serializer()
previous_handle = None
for handle in cloth_handles:
    if previous_handle is not None:
        previous_handle.visible = False
    handle.visible = True
    serializer.insert_sleep(dt)
    previous_handle = handle
serializer.show(height=500)
╭────── viser (listening *:8082) ───────╮
│ HTTP      │ http://localhost:8082     │
│ Websocket │ ws://localhost:8082       │
╰───────────────────────────────────────╯
The animation shows the cloth falling under gravity and draping over the sphere. The simulation uses implicit time integration:
Inertia term: Keeps the cloth following its predicted trajectory based on velocity
Spring forces: Maintain cloth structure and resist stretching/shearing/bending
Collision constraints: Prevent the cloth from penetrating the sphere or table
Friction: Contact-aware velocity damping helps the cloth settle on surfaces
The solver finds positions at each timestep that balance inertia, springs, and collision constraints. Friction is applied outside the optimization as velocity damping for vertices in contact with surfaces.