# SMPL-H shape fitting

Optimizing SMPL-H body shape parameters to achieve a target height.

**Inputs:** Target height, initial shape parameters (zeros)  
**Outputs:** Shape (beta) parameters that produce a body with the desired height

Features used:
- {class}`~jaxls.Var` for shape parameters
- {func}`@jaxls.Cost.factory <jaxls.Cost.factory>` with `kind="constraint_eq_zero"` for the height constraint
- Augmented Lagrangian solver for constrained optimization

Note: the SMPL-H implementation here is minimal. For full-featured SMPL models in `jaxls`, see [egoallo](https://github.com/brentyi/egoallo) or [VideoMimic](https://github.com/hongsukchoi/videomimic).

In [1]:
import sys
from loguru import logger

# Replace loguru's default handler with a single stdout sink that prints a
# compact "LEVEL | message" line (the level name is padded to 8 characters).
logger.remove()
_stdout_handler_id = logger.add(
    sys.stdout,
    format="<level>{level: <8}</level> | {message}",
)

In [2]:
import io
import pathlib
import urllib.request
import zipfile

import jax
import jax.numpy as jnp
import jax_dataclasses as jdc
import jaxls
import numpy as np
from jax import Array

## Download SMPL-H model

The SMPL-H model represents human body shape using a low-dimensional parameterization. Shape variations are controlled by beta parameters that deform a template mesh.

In [3]:
# Download SMPL-H model if not already present.
smplh_path = pathlib.Path("/tmp/SMPLH_NEUTRAL.npz")

if not smplh_path.exists():
    print("Downloading SMPL-H model...")
    url = "https://brentyi.github.io/viser-example-assets/SMPLH_NEUTRAL.zip"
    # The timeout bounds each blocking socket operation so a stalled
    # connection fails instead of hanging the notebook indefinitely.
    with urllib.request.urlopen(url, timeout=60) as response:
        zip_data = io.BytesIO(response.read())
    with zipfile.ZipFile(zip_data) as zf:
        # Extract next to `smplh_path` so the cache location is defined once.
        zf.extractall(smplh_path.parent)
    # Fail loudly here rather than later in `SmplhModel.load` if the archive
    # layout does not place the .npz file where we expect it.
    if not smplh_path.exists():
        raise FileNotFoundError(
            f"Archive {url} did not contain {smplh_path.name} at its top level."
        )
    print(f"Downloaded to {smplh_path}")
else:
    print(f"Using cached model at {smplh_path}")

Using cached model at /tmp/SMPLH_NEUTRAL.npz


## SMPL-H model implementation

A minimal implementation of the SMPL-H body model. Shape is controlled by beta parameters, which are PCA coefficients that linearly combine learned shape basis vectors to deform the template mesh.

In [4]:
@jdc.pytree_dataclass
class SmplhModel:
    """SMPL-H human body model (shape-only subset).

    Holds the template mesh and the linear shape blend shapes. The
    pose-dependent parts of SMPL-H are not represented here.
    """

    faces: Array
    """Vertex indices for mesh faces, shape (faces, 3)."""
    v_template: Array
    """Template mesh vertices, shape (verts, 3)."""
    shapedirs: Array
    """Shape blend shape bases, shape (verts, 3, n_betas)."""

    @staticmethod
    def load(npz_path: pathlib.Path) -> "SmplhModel":
        """Load model from an SMPL-H ``.npz`` file.

        Reads the ``f`` (faces), ``v_template`` and ``shapedirs`` arrays and
        converts them to int32 / float32 JAX arrays.
        """
        # NOTE(review): allow_pickle=True permits unpickling object arrays,
        # which can execute arbitrary code. Only load files from a trusted
        # source; consider allow_pickle=False if the three arrays read below
        # are plain numeric arrays.
        # The context manager closes the underlying archive once we're done.
        with np.load(npz_path, allow_pickle=True) as params:
            return SmplhModel(
                faces=jnp.array(params["f"].astype(np.int32)),
                v_template=jnp.array(params["v_template"].astype(np.float32)),
                shapedirs=jnp.array(params["shapedirs"].astype(np.float32)),
            )

    def get_vertices(self, betas: Array) -> Array:
        """Compute mesh vertices for given shape parameters.

        Args:
            betas: 1-D array of shape coefficients. Its length selects how many
                of the leading shape bases in ``shapedirs`` are used.

        Returns:
            Vertex positions with the same shape as ``v_template``.

        Raises:
            ValueError: If more betas are given than the model has shape bases.
        """
        num_betas = betas.shape[0]
        available = self.shapedirs.shape[-1]
        # Shapes are static under JAX tracing, so this Python check is safe
        # inside jit and gives a clearer message than an einsum shape error.
        if num_betas > available:
            raise ValueError(
                f"Got {num_betas} betas but the model only has "
                f"{available} shape components."
            )
        # Apply shape blend shapes: v = v_template + shapedirs @ betas.
        return self.v_template + jnp.einsum(
            "vxb,b->vx", self.shapedirs[:, :, :num_betas], betas
        )

    def get_height(self, betas: Array) -> Array:
        """Compute body height as the extent of the vertex y-coordinates.

        SMPL uses a y-up convention, so height is ``max(y) - min(y)`` over all
        vertices of the shaped mesh.
        """
        verts = self.get_vertices(betas)
        return jnp.max(verts[:, 1]) - jnp.min(verts[:, 1])

In [5]:
# Load the model.
model = SmplhModel.load(smplh_path)

# Check the template (zero-beta) height. Use as many betas as the model has
# shape components instead of assuming a fixed count.
num_model_betas = model.shapedirs.shape[-1]
template_height = float(model.get_height(jnp.zeros(num_model_betas)))
print(
    f"Template mesh: {model.v_template.shape[0]} vertices, {model.faces.shape[0]} faces"
)
print(f"Template height (beta=0): {template_height:.3f} m")

Template mesh: 6890 vertices, 13776 faces
Template height (beta=0): 1.717 m


## Problem setup

We optimize the first 10 beta parameters to reach a target height of 2.0 meters, well above the template's height, while regularizing the betas toward zero to keep the body shape natural.

In [6]:
# Target height in meters.
TARGET_HEIGHT: float = 2.0
# Number of leading SMPL-H shape components that are optimized.
NUM_BETAS: int = 10


# Variable for shape parameters; its default value is an all-zero beta vector.
class BetaVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.zeros(NUM_BETAS),
):
    """SMPL-H beta (shape) parameters."""


beta_var = BetaVar(id=0)


@jaxls.Cost.factory(kind="constraint_eq_zero")
def height_constraint(
    vals: jaxls.VarValues,
    var: BetaVar,
    model: SmplhModel,
    target_height: float,
) -> jax.Array:
    """Equality constraint driving the body height to ``target_height``.

    The returned residual is a scalar: current height minus target height.
    It is zero exactly when the constraint is satisfied.
    """
    return model.get_height(vals[var]) - target_height


@jaxls.Cost.factory
def beta_regularization(
    vals: jaxls.VarValues,
    var: BetaVar,
    weight: float,
) -> jax.Array:
    """Pull every beta toward zero (the template shape) with a uniform weight."""
    betas = vals[var]
    scaled = weight * betas
    return scaled

## Solving

When constraints are present, jaxls automatically uses an Augmented Lagrangian method. The solver iteratively adjusts Lagrange multipliers and penalty parameters to satisfy the constraint.

In [7]:
# Build the optimization problem: an equality constraint on height plus a
# regularization cost that keeps the betas near zero.
costs: list[jaxls.Cost] = [
    height_constraint(beta_var, model, TARGET_HEIGHT),
    beta_regularization(beta_var, weight=0.5),
]

# Initial values: zeros.
initial_betas = jnp.zeros(NUM_BETAS)
initial_vals = jaxls.VarValues.make([beta_var.with_value(initial_betas)])

# Build the problem.
problem = jaxls.LeastSquaresProblem(costs, [beta_var])

# Visualize the problem structure.
problem.show()

In [8]:
# Analyze the problem (which logs how costs and constraints are grouped),
# then report the starting and target heights.
problem = problem.analyze()

initial_height_m = float(model.get_height(initial_betas))
print(f"Initial height: {initial_height_m:.3f} m")
print(f"Target height: {TARGET_HEIGHT:.3f} m")

[1mINFO    [0m | Building optimization problem with 2 terms and 1 variables: 1 costs, 1 eq_zero, 0 leq_zero, 0 geq_zero
[1mINFO    [0m | Vectorizing group with 1 costs, 1 variables each: beta_regularization
[1mINFO    [0m | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_height_constraint
Initial height: 1.717 m
Target height: 2.000 m


In [9]:
# Solve. Because the problem contains a constraint, jaxls switches to its
# Augmented Lagrangian solver automatically.
termination = jaxls.TerminationConfig(cost_tolerance=1e-8)
solution = problem.solve(
    initial_vals,
    linear_solver="dense_cholesky",
    termination=termination,
)

optimized_betas = solution[beta_var]
final_height = model.get_height(optimized_betas)

# Summarize: achieved height, absolute error in centimeters, and beta norm.
final_height_m = float(final_height)
height_error_cm = abs(final_height_m - TARGET_HEIGHT) * 100
beta_norm = float(jnp.linalg.norm(optimized_betas))
print(f"\nOptimized height: {final_height_m:.4f} m")
print(f"Height error: {height_error_cm:.2f} cm")
print(f"Beta norm: {beta_norm:.3f}")

[1mINFO    [0m | Augmented Lagrangian: initial snorm=2.8264e-01, csupn=2.8264e-01, max_rho=1.0000e+01, constraint_dim=1
[1mINFO    [0m |  step #0: cost=0.0000 lambd=0.0005
[1mINFO    [0m |      - beta_regularization(1): 0.00000 (avg 0.00000)
[1mINFO    [0m |      - augmented_height_constraint(1): 0.79885 (avg 0.79885)
[1mINFO    [0m |      accepted=True ATb_norm=4.61e-01 cost_prev=0.7988 cost_new=0.5705
[1mINFO    [0m |  step #1: cost=0.1629 lambd=0.0003
[1mINFO    [0m |      - beta_regularization(1): 0.16292 (avg 0.01629)
[1mINFO    [0m |      - augmented_height_constraint(1): 0.40754 (avg 0.40754)
[1mINFO    [0m |  step #2: cost=0.1629 lambd=0.0005
[1mINFO    [0m |      - beta_regularization(1): 0.16292 (avg 0.01629)
[1mINFO    [0m |      - augmented_height_constraint(1): 0.40754 (avg 0.40754)
[1mINFO    [0m |  step #3: cost=0.1629 lambd=0.0010
[1mINFO    [0m |      - beta_regularization(1): 0.16292 (avg 0.01629)
[1mINFO    [0m |      - augmented_height_co

## Visualization

Compare the template mesh (beta=0) with the optimized shape side by side.

In [None]:
import contextlib
import io
import viser

# Get vertices for both configurations.
initial_verts = np.array(model.get_vertices(initial_betas))
optimized_verts = np.array(model.get_vertices(optimized_betas))
faces = np.array(model.faces)

# Compute heights for labels.
initial_height = float(model.get_height(initial_betas))
optimized_height = float(final_height)

# Compute y offsets that put each mesh's lowest vertex at y=0 (SMPL uses y-up).
# After shifting, each mesh's top sits at y == its height, which the labels use.
initial_y_offset = -initial_verts[:, 1].min()
optimized_y_offset = -optimized_verts[:, 1].min()

# Offset along x for side-by-side placement.
x_offset = 0.8

# Create Viser server (suppress output). On notebook re-runs, stop the server
# from the previous run first; on the first run `server` is not yet defined.
try:
    server.stop()
except NameError:
    pass
with (
    contextlib.redirect_stdout(io.StringIO()),
    contextlib.redirect_stderr(io.StringIO()),
):
    server = viser.ViserServer(verbose=False)

# Use y-up to match SMPL coordinates directly.
server.scene.set_up_direction("+y")

# Set initial camera position in front of the figures (SMPL faces +z). From
# here the camera looks toward -z, so +x appears on the right of the view.
server.initial_camera.position = (0.0, 2.0, 4.0)
server.initial_camera.look_at = (0.0, 0.9, 0.0)

# Add ground grid (xz plane for y-up) with fade.
server.scene.add_grid(
    "/ground",
    width=4.0,
    height=2.0,
    plane="xz",
    infinite_grid=True,
    fade_distance=10.0,
    cell_color=(200, 200, 200),
    section_color=(170, 170, 170),
)

# Add initial mesh (template, at -x: on the left from the initial camera).
server.scene.add_mesh_simple(
    "/initial_mesh",
    vertices=initial_verts + np.array([[-x_offset, initial_y_offset, 0.0]]),
    faces=faces,
    color=(70, 130, 180),  # Steel blue
    flat_shading=False,
)

# Add optimized mesh (tall figure, at +x: on the right from the initial camera).
server.scene.add_mesh_simple(
    "/optimized_mesh",
    vertices=optimized_verts + np.array([[x_offset, optimized_y_offset, 0.0]]),
    faces=faces,
    color=(34, 139, 34),  # Forest green
    flat_shading=False,
)

# Add height labels slightly above the top of each mesh.
server.scene.add_label(
    "/initial_label",
    text=f"Template: {initial_height:.2f}m",
    position=(-x_offset, initial_height + 0.15, 0.0),
    anchor="bottom-center",
)
server.scene.add_label(
    "/optimized_label",
    text=f"Optimized: {optimized_height:.2f}m",
    position=(x_offset, optimized_height + 0.15, 0.0),
    anchor="bottom-center",
)

# Display inline in the notebook.
server.scene.show(height=500)

The optimization finds shape parameters that satisfy the height constraint while keeping the body shape natural (small beta norm). The regularization discourages extreme deformations that could produce unrealistic body shapes.

For more on constrained optimization, see {doc}`/guide/advanced/constraints`.