Truss analysis#

In this notebook, we solve a truss analysis problem: computing forces and deformations in a 2D pin-jointed frame under load.

Features used:

  • Var subclassing for node displacement variables

  • @jaxls.Cost.factory for bar element strain energy

  • Equality constraints for fixed supports

  • Batched cost construction for all members

This is a classic introductory finite element analysis (FEA) problem using 1D bar elements.

Hide code cell source

import sys
from loguru import logger

# Replace loguru's existing handlers with a single stdout handler using a
# compact "LEVEL    | message" format; the solver's INFO progress log below is
# printed through it. (The trailing `;` presumably just keeps the notebook from
# echoing the handler id returned by logger.add.)
logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls

Truss element theory#

A truss is a structure of bar elements connected at pin joints (nodes). Each bar:

  • Carries only axial force (tension or compression)

  • Has stiffness \(k = \frac{EA}{L}\) where \(E\) is Young’s modulus, \(A\) is cross-sectional area, \(L\) is length

The strain energy in a bar element is

\[U = \frac{1}{2} k (\Delta L)^2 = \frac{1}{2} \frac{EA}{L} (L' - L)^2\]

where \(L'\) is the deformed length.

class NodeVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros(2)):
    """2D node displacement variable [dx, dy] in meters.

    The value is a length-2 array; ``default_factory`` supplies a zero
    displacement (presumably used by jaxls when no explicit value is given).
    In this notebook an array of ids, e.g. ``NodeVar(id=jnp.arange(n))``, is
    used to refer to several nodes at once.
    """

Cost functions#

  1. Bar strain energy: Penalizes elongation/compression of each member

  2. Support constraints: Fix displacements at support nodes

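  3. Load application: Prescribe the displacement at the load point. No external force is applied explicitly; the remaining nodes settle into equilibrium as the solver minimizes the total strain energy subject to these constraints.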

@jaxls.Cost.factory
def bar_strain_energy(
    vals: jaxls.VarValues,
    node_i: NodeVar,
    node_j: NodeVar,
    pos_i: jax.Array,
    pos_j: jax.Array,
    EA: float,
) -> jax.Array:
    """Residual for one bar whose squared norm is (EA / L0) * (L - L0)^2.

    That squared norm is twice the bar's strain energy
    (1/2) * (EA / L0) * (L - L0)^2, where L0 is the undeformed length and L
    the deformed length.

    Args:
        vals: Current variable values.
        node_i, node_j: Displacement variables at the two ends of the bar.
        pos_i, pos_j: Undeformed positions of the two ends.
        EA: Axial stiffness (Young's modulus x cross-sectional area).

    Returns:
        A 2-vector parallel to the deformed bar axis.
    """
    # Undeformed bar axis and length.
    rest_axis = pos_j - pos_i
    rest_length = jnp.sqrt(jnp.sum(rest_axis**2))

    # Deformed bar axis: undeformed axis plus relative end displacement.
    disp_start = vals[node_i]
    disp_end = vals[node_j]
    current_axis = rest_axis + (disp_end - disp_start)
    current_length = jnp.sqrt(jnp.sum(current_axis**2))

    # sqrt(EA/L0) * (L - L0) along the unit deformed axis, written as
    # (1 - L0/L) * L_vec to avoid normalizing explicitly. Using a 2D residual
    # rather than a scalar lets the bar contribute up to a rank-2 block to
    # J^T J (the Gauss-Newton Hessian approximation) instead of rank-1.
    gain = jnp.sqrt(EA / rest_length) * (1 - rest_length / current_length)
    return gain * current_axis


@jaxls.Cost.factory(kind="constraint_eq_zero")
def pin_support(
    vals: jaxls.VarValues,
    node: NodeVar,
) -> jax.Array:
    """Equality-constraint residual that pins a node in place.

    Returns the node's [dx, dy] displacement; the zero-equality constraint
    drives both components to zero.
    """
    pinned_disp = vals[node]
    return pinned_disp


@jaxls.Cost.factory(kind="constraint_eq_zero")
def prescribed_displacement(
    vals: jaxls.VarValues,
    node: NodeVar,
    target_displacement: jax.Array,
) -> jax.Array:
    """Equality-constraint residual forcing a node to a target displacement.

    Args:
        vals: Current variable values.
        node: Node whose displacement is prescribed.
        target_displacement: Desired [dx, dy] displacement in meters.
    """
    current = vals[node]
    offset = current - target_displacement
    return offset

Truss geometry#

We model a Warren truss, a common bridge structure with diagonal members:

   5-----6-----7-----8
  /\    /\    /\    /\
 /  \  /  \  /  \  /  \
/    \/    \/    \/    \
0-----1-----2-----3-----4
  • Nodes 0-4: Bottom chord

  • Nodes 5-8: Top chord

  • Nodes 0 and 4 are pinned (fixed in x and y)

  • Load applied at center bottom node (node 2)

# Warren truss bridge geometry.
num_panels = 4  # Triangular panels along the span
panel_width = 3.0  # [m] horizontal width of one panel
height = 2.0  # [m] vertical distance between the chords
span = num_panels * panel_width  # [m] total span

# Node coordinates: bottom-chord nodes lie on y = 0 at the panel boundaries;
# top-chord nodes lie at y = height above each panel's midpoint.
bottom_nodes = [[k * panel_width, 0.0] for k in range(num_panels + 1)]
top_nodes = [[(k + 0.5) * panel_width, height] for k in range(num_panels)]
node_positions = jnp.array(bottom_nodes + top_nodes)
num_nodes = len(node_positions)

# Node ids: the bottom chord comes first (0..num_panels), then the top chord.
bottom_ids = list(range(num_panels + 1))
top_ids = list(range(num_panels + 1, num_nodes))

# Member connectivity as [start, end] id pairs, ordered as: bottom chord,
# top chord, then a left/right diagonal pair for every top node.
bottom_chord = [[a, b] for a, b in zip(bottom_ids, bottom_ids[1:])]
top_chord = [[a, b] for a, b in zip(top_ids, top_ids[1:])]
diagonals = [
    pair
    for k, apex in enumerate(top_ids)
    for pair in ([bottom_ids[k], apex], [apex, bottom_ids[k + 1]])
]
member_list = bottom_chord + top_chord + diagonals

members = jnp.array(member_list)
num_members = len(members)

# Axial stiffness shared by every member.
EA = 50000.0  # [N]

# Loading: the center bottom node is pushed straight down by 20 mm.
load_node_id = 2
load_displacement = jnp.array([0.0, -0.02])  # [m]

print("Warren Truss Bridge:")
print(f"  Span: {span} m, Height: {height} m")
print(f"  Nodes: {num_nodes}, Members: {num_members}")
print(f"  Member stiffness EA = {EA:.0f} N")
print(
    f"  Prescribed displacement at node {load_node_id}: {float(load_displacement[1]) * 1000:.1f} mm (vertical)"
)
Warren Truss Bridge:
  Span: 12.0 m, Height: 2.0 m
  Nodes: 9, Members: 15
  Member stiffness EA = 50000 N
  Prescribed displacement at node 2: -20.0 mm (vertical)

Problem construction#

# One batched displacement variable covering nodes 0..num_nodes-1.
node_vars = NodeVar(id=jnp.arange(num_nodes))

# The two ends of the bottom chord are pinned supports.
left_pin = 0
right_pin = num_panels

# Strain energy for every member as one batched cost, built from the member
# endpoint ids and their undeformed coordinates.
start_ids = members[:, 0]
end_ids = members[:, 1]
bar_costs = bar_strain_energy(
    NodeVar(id=start_ids),
    NodeVar(id=end_ids),
    node_positions[start_ids],
    node_positions[end_ids],
    EA,
)

# Equality constraints: both supports held fixed, load node moved to target.
support_costs = [pin_support(NodeVar(id=pin)) for pin in (left_pin, right_pin)]
load_cost = prescribed_displacement(NodeVar(id=load_node_id), load_displacement)

costs: list[jaxls.Cost] = [bar_costs, *support_costs, load_cost]

print(f"Created {len(costs)} cost objects")
print(f"Load applied at node {load_node_id}")
Created 4 cost objects
Load applied at node 2

Solving#

# Start the solver from zero displacement at every node.
zero_guess = node_vars.with_value(jnp.zeros((num_nodes, 2)))
initial_vals = jaxls.VarValues.make([zero_guess])

# Build the problem and visualize its structure, then run jaxls's analysis
# pass and solve from the zero initial guess.
problem = jaxls.LeastSquaresProblem(costs, [node_vars])
problem.show()
problem = problem.analyze()
solution = problem.solve(initial_vals)
INFO     | Building optimization problem with 18 terms and 9 variables: 15 costs, 3 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 2 constraints (constraint_eq_zero), 1 variables each: augmented_pin_support
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_prescribed_displacement
INFO     | Vectorizing group with 15 costs, 2 variables each: bar_strain_energy
INFO     | Augmented Lagrangian: initial snorm=2.0000e-02, csupn=2.0000e-02, max_rho=1.0000e+01, constraint_dim=6
INFO     |  step #0: cost=0.0000 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - augmented_pin_support(2): 0.00000 (avg 0.00000)
INFO     |      - augmented_prescribed_displacement(1): 0.00400 (avg 0.00200)
INFO     |      - bar_strain_energy(15): 0.00000 (avg 0.00000)
INFO     |      accepted=True ATb_norm=2.02e-01 cost_prev=0.0040 cost_new=0.0027
INFO     |  step #1: cost=0.0000 lambd=0.0003 inexact_tol=1.0e-02
INFO     |      - augmented_pin_support(2): 0.00088 (avg 0.00022)
INFO     |      - augmented_prescribed_displacement(1): 0.00177 (avg 0.00088)
INFO     |      - bar_strain_energy(15): 0.00001 (avg 0.00000)
INFO     |      accepted=True ATb_norm=8.03e-03 cost_prev=0.0027 cost_new=0.0027
INFO     |  AL update: snorm=1.3288e-02, csupn=1.3288e-02, max_rho=4.0000e+01
INFO     |  step #2: cost=0.0000 lambd=0.0001 inexact_tol=1.4e-03
INFO     |      - augmented_pin_support(2): 0.00552 (avg 0.00138)
INFO     |      - augmented_prescribed_displacement(1): 0.01104 (avg 0.00552)
INFO     |      - bar_strain_energy(15): 0.00001 (avg 0.00000)
INFO     |      accepted=True ATb_norm=6.56e-01 cost_prev=0.0166 cost_new=0.0164
INFO     |  step #3: cost=0.0002 lambd=0.0001 inexact_tol=1.4e-03
INFO     |      - augmented_pin_support(2): 0.00540 (avg 0.00135)
INFO     |      - augmented_prescribed_displacement(1): 0.01079 (avg 0.00540)
INFO     |      - bar_strain_energy(15): 0.00023 (avg 0.00001)
INFO     |      accepted=True ATb_norm=8.96e-03 cost_prev=0.0164 cost_new=0.0164
INFO     |  AL update: snorm=1.3102e-02, csupn=1.3102e-02, max_rho=1.6000e+02
INFO     |  step #4: cost=0.0002 lambd=0.0000 inexact_tol=1.7e-04
INFO     |      - augmented_pin_support(2): 0.02369 (avg 0.00592)
INFO     |      - augmented_prescribed_displacement(1): 0.04738 (avg 0.02369)
INFO     |      - bar_strain_energy(15): 0.00023 (avg 0.00001)
INFO     |      accepted=True ATb_norm=2.59e+00 cost_prev=0.0713 cost_new=0.0691
INFO     |  step #5: cost=0.0036 lambd=0.0000 inexact_tol=1.7e-04
INFO     |      - augmented_pin_support(2): 0.02184 (avg 0.00546)
INFO     |      - augmented_prescribed_displacement(1): 0.04364 (avg 0.02182)
INFO     |      - bar_strain_energy(15): 0.00365 (avg 0.00012)
INFO     |      accepted=True ATb_norm=6.89e-03 cost_prev=0.0691 cost_new=0.0691
INFO     |  AL update: snorm=1.2409e-02, csupn=1.2409e-02, max_rho=6.4000e+02
INFO     |  step #6: cost=0.0036 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.08763 (avg 0.02191)
INFO     |      - augmented_prescribed_displacement(1): 0.17503 (avg 0.08751)
INFO     |      - bar_strain_energy(15): 0.00365 (avg 0.00012)
INFO     |      accepted=True ATb_norm=9.80e+00 cost_prev=0.2663 cost_new=0.2399
INFO     |  step #7: cost=0.0425 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.06635 (avg 0.01659)
INFO     |      - augmented_prescribed_displacement(1): 0.13097 (avg 0.06548)
INFO     |      - bar_strain_energy(15): 0.04254 (avg 0.00142)
INFO     |      accepted=False ATb_norm=6.01e-02 cost_prev=0.2399 cost_new=0.2399
INFO     |  AL update: snorm=1.0176e-02, csupn=1.0176e-02, max_rho=2.5600e+03
INFO     |  step #8: cost=0.0425 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.24687 (avg 0.06172)
INFO     |      - augmented_prescribed_displacement(1): 0.48419 (avg 0.24210)
INFO     |      - bar_strain_energy(15): 0.04254 (avg 0.00142)
INFO     |      accepted=True ATb_norm=3.22e+01 cost_prev=0.7736 cost_new=0.5982
INFO     |  step #9: cost=0.2582 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.12402 (avg 0.03101)
INFO     |      - augmented_prescribed_displacement(1): 0.21590 (avg 0.10795)
INFO     |      - bar_strain_energy(15): 0.25824 (avg 0.00861)
INFO     |      accepted=True ATb_norm=1.81e-01 cost_prev=0.5982 cost_new=0.5981
INFO     |  AL update: snorm=5.6058e-03, csupn=5.6058e-03, max_rho=1.0240e+04
INFO     |  step #10: cost=0.2583 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.40213 (avg 0.10053)
INFO     |      - augmented_prescribed_displacement(1): 0.63929 (avg 0.31965)
INFO     |      - bar_strain_energy(15): 0.25834 (avg 0.00861)
INFO     |      accepted=True ATb_norm=7.44e+01 cost_prev=1.2998 cost_new=0.9458
INFO     |  step #11: cost=0.6498 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.14261 (avg 0.03565)
INFO     |      - augmented_prescribed_displacement(1): 0.15344 (avg 0.07672)
INFO     |      - bar_strain_energy(15): 0.64977 (avg 0.02166)
INFO     |      accepted=False ATb_norm=1.49e-01 cost_prev=0.9458 cost_new=0.9459
INFO     |  AL update: snorm=1.5755e-03, csupn=1.5755e-03, max_rho=1.0240e+04
INFO     |  step #12: cost=0.6498 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.35443 (avg 0.08861)
INFO     |      - augmented_prescribed_displacement(1): 0.30377 (avg 0.15188)
INFO     |      - bar_strain_energy(15): 0.64977 (avg 0.02166)
INFO     |      accepted=True ATb_norm=2.80e+01 cost_prev=1.3080 cost_new=1.2693
INFO     |  step #13: cost=0.8255 lambd=0.0000 inexact_tol=6.4e-06
INFO     |      - augmented_pin_support(2): 0.24262 (avg 0.06066)
INFO     |      - augmented_prescribed_displacement(1): 0.20120 (avg 0.10060)
INFO     |      - bar_strain_energy(15): 0.82545 (avg 0.02752)
INFO     |      accepted=True ATb_norm=3.85e-02 cost_prev=1.2693 cost_new=1.2692
INFO     |  AL update: snorm=8.3995e-04, csupn=8.3995e-04, max_rho=4.0960e+04
INFO     |  step #14: cost=0.8254 lambd=0.0000 inexact_tol=1.7e-06
INFO     |      - augmented_pin_support(2): 0.24099 (avg 0.06025)
INFO     |      - augmented_prescribed_displacement(1): 0.25535 (avg 0.12768)
INFO     |      - bar_strain_energy(15): 0.82543 (avg 0.02751)
INFO     |      accepted=True ATb_norm=5.20e+01 cost_prev=1.3218 cost_new=1.2709
INFO     |  step #15: cost=0.9478 lambd=0.0000 inexact_tol=1.7e-06
INFO     |      - augmented_pin_support(2): 0.08918 (avg 0.02229)
INFO     |      - augmented_prescribed_displacement(1): 0.23397 (avg 0.11698)
INFO     |      - bar_strain_energy(15): 0.94777 (avg 0.03159)
INFO     |      accepted=False ATb_norm=2.07e-02 cost_prev=1.2709 cost_new=1.2709
INFO     |  AL update: snorm=3.4766e-04, csupn=3.4766e-04, max_rho=4.0960e+04
INFO     |  step #16: cost=0.9478 lambd=0.0000 inexact_tol=1.7e-06
INFO     |      - augmented_pin_support(2): 0.12436 (avg 0.03109)
INFO     |      - augmented_prescribed_displacement(1): 0.09748 (avg 0.04874)
INFO     |      - bar_strain_energy(15): 0.94777 (avg 0.03159)
INFO     |      accepted=True ATb_norm=1.85e+01 cost_prev=1.1696 cost_new=1.1627
INFO     |  step #17: cost=1.0021 lambd=0.0000 inexact_tol=1.7e-06
INFO     |      - augmented_pin_support(2): 0.09877 (avg 0.02469)
INFO     |      - augmented_prescribed_displacement(1): 0.06189 (avg 0.03094)
INFO     |      - bar_strain_energy(15): 1.00208 (avg 0.03340)
INFO     |      accepted=True ATb_norm=1.52e-02 cost_prev=1.1627 cost_new=1.1627
INFO     |  AL update: snorm=5.4568e-05, csupn=5.4568e-05, max_rho=4.0960e+04
INFO     |  step #18: cost=1.0021 lambd=0.0000 inexact_tol=6.0e-07
INFO     |      - augmented_pin_support(2): 0.10889 (avg 0.02722)
INFO     |      - augmented_prescribed_displacement(1): 0.06537 (avg 0.03269)
INFO     |      - bar_strain_energy(15): 1.00206 (avg 0.03340)
INFO     |      accepted=True ATb_norm=3.62e+00 cost_prev=1.1763 cost_new=1.1761
INFO     |  step #19: cost=1.0125 lambd=0.0000 inexact_tol=6.0e-07
INFO     |      - augmented_pin_support(2): 0.10107 (avg 0.02527)
INFO     |      - augmented_prescribed_displacement(1): 0.06252 (avg 0.03126)
INFO     |      - bar_strain_energy(15): 1.01249 (avg 0.03375)
INFO     |      accepted=False ATb_norm=1.04e-02 cost_prev=1.1761 cost_new=1.1761
INFO     |  AL update: snorm=1.3225e-05, csupn=1.3225e-05, max_rho=4.0960e+04
INFO     |  step #20: cost=1.0125 lambd=0.0000 inexact_tol=6.0e-07
INFO     |      - augmented_pin_support(2): 0.10341 (avg 0.02585)
INFO     |      - augmented_prescribed_displacement(1): 0.06317 (avg 0.03158)
INFO     |      - bar_strain_energy(15): 1.01249 (avg 0.03375)
INFO     |      accepted=True ATb_norm=8.31e-01 cost_prev=1.1791 cost_new=1.1790
INFO     |  step #21: cost=1.0147 lambd=0.0000 inexact_tol=6.0e-07
INFO     |      - augmented_pin_support(2): 0.10162 (avg 0.02541)
INFO     |      - augmented_prescribed_displacement(1): 0.06267 (avg 0.03133)
INFO     |      - bar_strain_energy(15): 1.01474 (avg 0.03382)
INFO     |      accepted=True ATb_norm=1.68e-02 cost_prev=1.1790 cost_new=1.1790
INFO     |  AL update: snorm=3.1461e-06, csupn=3.1461e-06, max_rho=4.0960e+04
INFO     | Terminated @ iteration #22: cost=1.0147 criteria=[0 0 1], term_deltas=4.0e-05,9.9e-03,2.3e-08

Results and visualization#

Hide code cell source

# Solved per-node displacements and the resulting deformed coordinates.
displacements = solution[node_vars]
deformed_positions = displacements + node_positions


def compute_member_force(i: int, j: int, disp: jax.Array) -> jax.Array:
    """Return the axial force EA * strain in the member joining nodes i and j.

    Positive values mean tension, negative values compression. Strain is the
    engineering strain (L - L0) / L0 of the member's deformed length L against
    its undeformed length L0 (from the module-level ``node_positions``).

    Args:
        i: Start node index.
        j: End node index.
        disp: Node displacements array, shape (num_nodes, 2).

    Returns:
        Scalar axial force in the member.
    """
    rest = node_positions[j] - node_positions[i]
    rest_len = jnp.sqrt(jnp.sum(rest**2))
    stretched = rest + (disp[j] - disp[i])
    stretched_len = jnp.sqrt(jnp.sum(stretched**2))
    return EA * ((stretched_len - rest_len) / rest_len)


# Axial force in every member (positive = tension), vectorized over members.
member_forces = jax.vmap(lambda m: compute_member_force(m[0], m[1], displacements))(
    members
)

# Print results.
print("Node Displacements:")
print(f"{'Node':>4} {'dx [mm]':>10} {'dy [mm]':>10}")
print("-" * 26)
for i in range(num_nodes):
    dx, dy = displacements[i] * 1000  # Convert to mm
    print(f"{i:>4} {float(dx):>10.3f} {float(dy):>10.3f}")

# Column widths are shared by header and rows so the table lines up: the
# member label ("i-j") is right-aligned in the 6-wide "Member" column, and the
# Type column is 11 wide so "Compression" fits without overflowing.
print("\nMember Forces:")
print(f"{'Member':>6} {'Force [kN]':>12} {'Type':>11}")
print("-" * 31)
for m, f in zip(members, member_forces):
    label = f"{int(m[0])}-{int(m[1])}"
    f_kN = float(f) / 1000
    typ = "Tension" if f > 0 else "Compression"
    print(f"{label:>6} {f_kN:>12.2f} {typ:>11}")
Node Displacements:
Node    dx [mm]    dy [mm]
--------------------------
   0     -0.003     -0.001
   1     -1.153    -11.739
   2      0.000    -19.999
   3      1.153    -11.739
   4      0.003     -0.001
   5      4.598     -5.452
   6      2.287    -16.315
   7     -2.287    -16.315
   8     -4.598     -5.452

Member Forces:
Member   Force [kN]       Type
------------------------------
0- 1        -0.02 Compression
1- 2         0.02    Tension
2- 3         0.02    Tension
3- 4        -0.02 Compression
5- 6        -0.04 Compression
6- 7        -0.08 Compression
7- 8        -0.04 Compression
0- 5        -0.03 Compression
5- 1         0.03    Tension
1- 6        -0.03 Compression
6- 2         0.03    Tension
2- 7         0.03    Tension
7- 3        -0.03 Compression
3- 8         0.03    Tension
8- 4        -0.03 Compression

Hide code cell source

import plotly.graph_objects as go
from IPython.display import HTML

# Visualization: the undeformed truss in light gray, overlaid with the deformed
# truss (displacements magnified by `scale`) colored by member force.
scale = 20  # Displacement magnification for visibility
scaled_deformed = node_positions + scale * displacements

fig = go.Figure()

# Undeformed members, drawn first so they sit underneath.
for m in members:
    i, j = int(m[0]), int(m[1])
    x_pair = [float(node_positions[i, 0]), float(node_positions[j, 0])]
    y_pair = [float(node_positions[i, 1]), float(node_positions[j, 1])]
    fig.add_trace(
        go.Scatter(
            x=x_pair,
            y=y_pair,
            mode="lines",
            line=dict(color="lightgray", width=6),
            showlegend=False,
        )
    )

# Deformed members. The shade runs from saturated (largest |force|) to pale
# (zero force); the hue is red for tension and blue for compression.
max_force = float(jnp.max(jnp.abs(member_forces))) + 1e-6
for m, member_force in zip(members, member_forces):
    i, j = int(m[0]), int(m[1])
    force = float(member_force)
    shade = int(80 + 175 * (1 - min(abs(force) / max_force, 1.0)))
    color = (
        f"rgba(220, {shade}, {shade}, 1)"
        if force > 0
        else f"rgba({shade}, {shade}, 220, 1)"
    )
    fig.add_trace(
        go.Scatter(
            x=[float(scaled_deformed[i, 0]), float(scaled_deformed[j, 0])],
            y=[float(scaled_deformed[i, 1]), float(scaled_deformed[j, 1])],
            mode="lines",
            line=dict(color=color, width=5),
            showlegend=False,
        )
    )

# Deformed node positions, labeled with their ids.
fig.add_trace(
    go.Scatter(
        x=[float(v) for v in scaled_deformed[:, 0]],
        y=[float(v) for v in scaled_deformed[:, 1]],
        mode="markers+text",
        marker=dict(size=10, color="steelblue"),
        text=[str(node_id) for node_id in range(num_nodes)],
        textposition="top center",
        textfont=dict(size=9),
        showlegend=False,
    )
)

# Red downward arrow and "Load" label just below the deformed load node.
load_x = float(scaled_deformed[load_node_id, 0])
load_y = float(scaled_deformed[load_node_id, 1])
fig.add_annotation(
    x=load_x,
    y=load_y - 0.8,
    ax=0,
    ay=-40,
    xref="x",
    yref="y",
    axref="pixel",
    ayref="pixel",
    showarrow=True,
    arrowhead=2,
    arrowsize=1.5,
    arrowwidth=3,
    arrowcolor="red",
)
fig.add_annotation(
    x=load_x,
    y=load_y - 1.3,
    text="Load",
    showarrow=False,
    font=dict(size=12, color="red"),
)

fig.update_layout(
    title=f"Warren Truss Bridge Analysis (displacements x{scale})",
    xaxis=dict(title="x [m]", scaleanchor="y", scaleratio=1),
    yaxis=dict(title="y [m]"),
    height=400,
    showlegend=False,
    margin=dict(t=80, b=50, l=50, r=50),
)

# Color key in the top-left corner of the plot area.
fig.add_annotation(
    x=0.02,
    y=0.98,
    xref="paper",
    yref="paper",
    text="Red = Tension, Blue = Compression",
    showarrow=False,
    font=dict(size=10),
    align="left",
    bgcolor="white",
    bordercolor="gray",
    borderwidth=1,
)

HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

The solver found the equilibrium configuration of the Warren truss bridge under a prescribed displacement:

  • Deformed shape: Shown with exaggerated displacements for visibility

  • Member colors: Red indicates tension, blue indicates compression

Key observations:

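  • Bottom chord: the central members (1-2, 2-3) are in tension (red), while the end members (0-1, 3-4) carry slight compression (blue). Because both supports are pinned (neither can slide horizontally), they restrain the spreading of the truss, which shifts the bottom chord toward compression compared with a pin-and-roller support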

  • Top chord is in compression (blue) - it’s being squeezed as the bridge sags

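  • Diagonal members alternate between tension and compression within each half of the span; the two diagonals meeting at the loaded node (6-2 and 2-7) are both in tension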

  • Maximum deflection occurs at the center where the load is applied

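This broadly matches the textbook behavior of a truss bridge under a center point load. Note, however, that with two pinned supports (rather than a pin and a roller) the truss is statically indeterminate: 15 member forces plus 4 reaction components exceed the 18 nodal equilibrium equations. The results therefore differ from the simply-supported case, most visibly in the end bottom-chord members.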

Varying displacements#

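We can animate the truss response to different prescribed displacements. As the displacement increases, the internal forces grow approximately proportionally. The material is linear elastic, and the displacements are small compared with the member lengths, so the geometric nonlinearity in the bar cost (which uses the deformed length) has little effect.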

Using jax.vmap, we solve for all displacement magnitudes in parallel.

Hide code cell source

# 21 evenly spaced prescribed deflections from 0 to 50 mm (values in meters).
displacement_magnitudes = jnp.linspace(0.0, 0.05, num=21)


def solve_for_displacement(disp_y: jax.Array) -> jax.Array:
    """Solve the truss with the load node pushed down by ``disp_y`` meters.

    Builds the same costs as the main problem (batched bar energies, two pin
    supports, and a prescribed displacement of [0, -disp_y] at
    ``load_node_id``). It then analyzes and solves quietly from the
    variables' default values.

    Args:
        disp_y: Magnitude of the downward displacement imposed at the load node.

    Returns:
        The solved displacement values for ``node_vars``.
    """
    start_ids, end_ids = members[:, 0], members[:, 1]
    bars = bar_strain_energy(
        NodeVar(id=start_ids),
        NodeVar(id=end_ids),
        node_positions[start_ids],
        node_positions[end_ids],
        EA,
    )
    supports = [pin_support(NodeVar(id=pin)) for pin in (left_pin, right_pin)]
    load = prescribed_displacement(
        NodeVar(id=load_node_id), jnp.array([0.0, -disp_y])
    )
    costs_d: list[jaxls.Cost] = [bars, *supports, load]

    solved = (
        jaxls.LeastSquaresProblem(costs_d, [node_vars])
        .analyze()
        .solve(verbose=False)
    )
    return solved[node_vars]


# Batch the whole solve over every displacement magnitude with jax.vmap.
batched_solve = jax.vmap(solve_for_displacement)
all_displacements = batched_solve(displacement_magnitudes)
print(
    f"Solved for {len(displacement_magnitudes)} displacement values in parallel using vmap"
)
INFO     | Building optimization problem with 18 terms and 9 variables: 15 costs, 3 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing constraint group with 2 constraints (constraint_eq_zero), 1 variables each: augmented_pin_support
INFO     | Vectorizing group with 15 costs, 2 variables each: bar_strain_energy
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_prescribed_displacement
Solved for 21 displacement values in parallel using vmap

Hide code cell source

# Animated visualization with force coloring; displacements are magnified
# by this factor so the deflection is visible.
scale_anim = 10


def get_member_color(force: float, max_force: float) -> str:
    """Map a member force to an RGBA color string.

    Positive forces (tension) use a red hue, everything else (compression or
    zero) a blue hue. The two non-dominant channels fade linearly from 255
    (zero force) down to 80 as |force| approaches ``max_force``; a tiny
    epsilon keeps the ratio finite when ``max_force`` is zero.
    """
    ratio = abs(force) / (max_force + 1e-6)
    fade = int(80 + 175 * (1 - min(ratio, 1.0)))
    red, green, blue = (220, fade, fade) if force > 0 else (fade, fade, 220)
    return f"rgba({red}, {green}, {blue}, 1)"


# Compute member forces for all displacement configurations.
def compute_forces_for_disp(disp: jax.Array) -> jax.Array:
    """Return the axial force of every member for one displacement field.

    Vectorizes ``compute_member_force`` over the start/end node columns of
    ``members`` while sharing the same ``disp`` across all members.
    """
    batched_force = jax.vmap(compute_member_force, in_axes=(0, 0, None))
    return batched_force(members[:, 0], members[:, 1], disp)


# Member forces for every solved state, and the largest force magnitude over
# all states (used to normalize colors consistently between frames).
all_forces = jax.vmap(compute_forces_for_disp)(all_displacements)
global_max_force = float(jnp.abs(all_forces).max())

def _truss_traces(disp: jax.Array, forces: jax.Array) -> list[go.Scatter]:
    """Build the plotly traces for one animation state.

    Args:
        disp: Node displacements for this state, indexed as ``disp[node, axis]``.
        forces: Axial force per member, ordered like ``members``.

    Returns:
        One line trace per member, colored by ``get_member_color`` relative to
        ``global_max_force``, followed by a single node-marker trace. Node
        positions are displaced by ``scale_anim * disp``.
    """
    scaled_pos = node_positions + scale_anim * disp
    traces = []
    for idx, m in enumerate(members):
        mi, mj = int(m[0]), int(m[1])
        color = get_member_color(float(forces[idx]), global_max_force)
        traces.append(
            go.Scatter(
                x=[float(scaled_pos[mi, 0]), float(scaled_pos[mj, 0])],
                y=[float(scaled_pos[mi, 1]), float(scaled_pos[mj, 1])],
                mode="lines",
                line=dict(color=color, width=4),
                showlegend=False,
            )
        )
    traces.append(
        go.Scatter(
            x=[float(p) for p in scaled_pos[:, 0]],
            y=[float(p) for p in scaled_pos[:, 1]],
            mode="markers",
            marker=dict(size=8, color="steelblue"),
            showlegend=False,
        )
    )
    return traces


# One animation frame per solved displacement magnitude. The frame title shows
# the prescribed value next to the load node's computed vertical deflection.
frames = []
for i, (disp_mag, disp, forces) in enumerate(
    zip(displacement_magnitudes, all_displacements, all_forces)
):
    deflection = float(-disp[load_node_id, 1]) * 1000
    frames.append(
        go.Frame(
            data=_truss_traces(disp, forces),
            name=str(i),
            layout=go.Layout(
                title=f"Prescribed Displacement: {float(disp_mag) * 1000:.1f} mm, "
                f"Actual: {deflection:.1f} mm"
            ),
        )
    )

# Initial figure data: the first (zero-displacement) state. Its members carry
# (near-)zero force, so get_member_color draws them in its palest shade
# (close to white), not gray.
init_traces = _truss_traces(all_displacements[0], all_forces[0])

# Playback controls: "Play" advances one frame every 100 ms starting from the
# current frame; "Pause" stops the animation immediately.
play_button = dict(
    label="Play",
    method="animate",
    args=[
        None,
        dict(frame=dict(duration=100, redraw=True), fromcurrent=True),
    ],
)
pause_button = dict(
    label="Pause",
    method="animate",
    args=[
        [None],
        dict(frame=dict(duration=0, redraw=False), mode="immediate"),
    ],
)

# One slider step per solved state, labeled in whole millimeters and jumping to
# the frame with the matching name.
slider_steps = [
    dict(
        args=[
            [str(k)],
            dict(frame=dict(duration=0, redraw=True), mode="immediate"),
        ],
        label=f"{float(mag) * 1000:.0f}",
        method="animate",
    )
    for k, mag in enumerate(displacement_magnitudes)
]

fig_anim = go.Figure(
    data=init_traces,
    frames=frames,
    layout=go.Layout(
        title="Prescribed Displacement: 0.0 mm, Actual: 0.0 mm",
        xaxis=dict(title="x [m]", range=[-1, span + 1], scaleanchor="y", scaleratio=1),
        yaxis=dict(title="y [m]", range=[-1, height + 2]),
        height=450,
        margin=dict(b=100),
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=0,
                x=0.1,
                xanchor="right",
                buttons=[play_button, pause_button],
            )
        ],
        sliders=[
            dict(
                active=0,
                steps=slider_steps,
                x=0.1,
                len=0.8,
                xanchor="left",
                y=-0.15,
                currentvalue=dict(
                    prefix="Displacement (mm): ", visible=True, xanchor="center"
                ),
                transition=dict(duration=0),
            )
        ],
    ),
)

# Color key in the top-left corner of the plot area.
fig_anim.add_annotation(
    x=0.02,
    y=0.98,
    xref="paper",
    yref="paper",
    text="Red = Tension, Blue = Compression",
    showarrow=False,
    font=dict(size=10),
    align="left",
    bgcolor="white",
    bordercolor="gray",
    borderwidth=1,
)

HTML(fig_anim.to_html(full_html=False, include_plotlyjs="cdn", auto_play=False))