What’s traced?#

JAX’s JIT compilation requires distinguishing between traced values, which can change between calls without recompilation, and static values, which are baked into the compiled code and trigger recompilation when changed.

This page discusses:

  • What’s traced vs static in jaxls

  • Why variable IDs are traced

  • What triggers recompilation

  • Examples: leveraging traced values

Background: JAX tracing#

When JAX JIT-compiles a function, it traces the computation with abstract values. Values can be:

  • Traced: Represented by abstract shapes/dtypes during compilation; the actual values are substituted at runtime. Changing a traced value doesn’t trigger recompilation, as long as its shape and dtype stay the same.

  • Static: Baked into the compiled code. Changing static values triggers a new compilation.
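A minimal standalone JAX example of the distinction (independent of jaxls):

import functools

import jax
import jax.numpy as jnp

@jax.jit
def scale(x: jax.Array, factor: jax.Array) -> jax.Array:
    # Both arguments are traced: new values reuse the compiled code as
    # long as shapes and dtypes are unchanged.
    return x * factor

@functools.partial(jax.jit, static_argnames="exponent")
def power(x: jax.Array, exponent: int) -> jax.Array:
    # `exponent` is static: it is baked into the compiled code, so each
    # new exponent triggers a fresh compilation.
    return x**exponent

x = jnp.arange(4.0)
scale(x, jnp.asarray(2.0)); scale(x, jnp.asarray(3.0))  # One compilation.
power(x, 2); power(x, 3)  # Two compilations.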

jaxls’s static/traced split#

Traced (can change without recompilation)#

  • Variable IDs (Var.id): different variable subsets can reuse the same compiled code.

  • Variable values (VarValues): values change during optimization.

  • Jacobian values: computed from the current variable values.

  • Solver parameters (LM damping, CG tolerance): may adapt during a solve.

  • Augmented Lagrangian multipliers/penalties: updated between outer iterations.
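Because variable values are traced, a jitted solve can be rerun with new inputs at no extra compilation cost. A minimal sketch, assuming problem is an already-analyzed jaxls problem and init_a/init_b are two structurally identical VarValues:

@jax.jit
def solve(init_vals: jaxls.VarValues) -> jaxls.VarValues:
    # Initial values are traced; only the first call compiles.
    return problem.solve(init_vals).vals

solution_a = solve(init_a)
solution_b = solve(init_b)  # Same shapes and dtypes, so no recompilation.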

Static (changes trigger recompilation)#

  • Problem dimensions (_tangent_dim, _residual_dim): determine array shapes.

  • Cost counts and structure (_cost_counts): determine the vectorization structure.

  • Solver choice (linear_solver, sparse_mode): selects different code paths.

  • Variable tangent dimensions: determine Jacobian block sizes.

  • Constraint types (equality vs. inequality): different augmented Lagrangian update logic.
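Concretely, problem dimensions enter the compiled code through array shapes. A sketch, where make_costs is a hypothetical helper that builds one prior cost over n variables:

def make_costs(n: int) -> list[jaxls.Cost]:
    # Hypothetical helper: one prior cost over variables 0..n-1.
    return [prior_cost(MyVar(id=jnp.arange(n)), jnp.zeros(n))]

# Different cost counts mean different residual/Jacobian shapes, so a
# jitted solve compiles once per problem size.
problem_10 = jaxls.LeastSquaresProblem(make_costs(10)).analyze()
problem_20 = jaxls.LeastSquaresProblem(make_costs(20)).analyze()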

Why variable IDs are traced#

Var.id is a traced jax.Array, not a static int. This enables automatic vectorization of costs. These two approaches produce equivalent results after analysis:

# Approach A: Array of IDs (explicitly batched).
costs_a = [
    pairwise_cost(
        MyVar(id=jnp.arange(100)),      # IDs 0-99
        MyVar(id=jnp.arange(100) + 1),  # IDs 1-100
        data,
    )
]

# Approach B: List comprehension (implicitly batched).
costs_b = [
    pairwise_cost(MyVar(id=i), MyVar(id=i + 1), data[i])
    for i in range(100)
]

In approach A, the IDs are explicit arrays. In approach B, jaxls stacks the scalar IDs into arrays during analyze(). In both cases, the analyzed problem vmaps over the ID arrays to compute residuals and Jacobians in parallel.
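Conceptually (a pure-JAX sketch of the idea, not jaxls internals), stacked IDs let a single residual function be mapped over all cost instances:

import jax
import jax.numpy as jnp

# Hypothetical flat storage for 101 three-dimensional variables.
all_values = jnp.zeros((101, 3))

def pairwise_residual(i: jax.Array, j: jax.Array) -> jax.Array:
    # Hypothetical residual between two variables, looked up by ID.
    return all_values[i] - all_values[j]

ids_a = jnp.arange(100)
ids_b = jnp.arange(100) + 1
residuals = jax.vmap(pairwise_residual)(ids_a, ids_b)  # Shape: (100, 3).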

This also works with inline lambda functions:

# Approach C: Inline lambdas with list comprehension.
costs_c = [
    jaxls.Cost.factory(
        # Per-instance data (target[i]) is passed as an argument rather
        # than captured in the closure: Python's late-binding closures
        # would otherwise make every lambda see the final value of i.
        lambda vals, v, t: vals[v] - t
    )(MyVar(id=i), target[i])
    for i in range(100)
]

Even though this creates 100 separate lambda objects, jaxls recognizes them as the same cost type using bytecode analysis: dis.Bytecode extracts the instruction sequence, which is identical across all 100 lambdas.
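A minimal sketch of that check (illustrative only, not jaxls’s exact grouping logic):

import dis

lambdas = [(lambda vals, v, t: vals[v] - t) for _ in range(100)]
sequences = [
    [(inst.opname, inst.argrepr) for inst in dis.Bytecode(fn)]
    for fn in lambdas
]
assert all(seq == sequences[0] for seq in sequences)  # Identical bytecode.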

Examples: leveraging traced values#

Different variable assignments#

The same cost structure can connect different variables without triggering recompilation; jax.jit() compiles the function on the first call and reuses the compiled code afterward:

@jax.jit
def solve_with_ids(
    var_ids: jax.Array,  # Which variables to use.
    targets: jax.Array,
) -> jaxls.VarValues:
    # Same cost type, but connecting different variable IDs.
    costs = [prior_cost(MyVar(id=var_ids), targets)]
    problem = jaxls.LeastSquaresProblem(costs).analyze()
    return problem.solve(initial_vals).vals

# First call compiles; subsequent calls reuse compiled code.
solution1 = solve_with_ids(jnp.array([0, 1, 2]), targets_a)
solution2 = solve_with_ids(jnp.array([3, 4, 5]), targets_b)  # No recompilation.

Batched solves with vmap#

With jax.vmap(), solve multiple problem instances in parallel by vmapping over initial values:

def solve_one(init_vals: jaxls.VarValues) -> jaxls.VarValues:
    # `problem` is a single analyzed problem closed over from the
    # enclosing scope; only the initial values are batched.
    return problem.solve(init_vals).vals

# Solve 100 problems in parallel: batched_init_vals carries a leading
# batch axis of size 100.
batched_solutions = jax.jit(jax.vmap(solve_one))(batched_init_vals)

Sequential solves with scan#

With jax.lax.scan(), solve a sequence of problems where each uses the previous solution as its initial guess:

def solve_step(
    vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    # `target` is a traced scan input; the problem structure is fixed,
    # so analyze() runs once when the scan body is traced.
    costs = [prior_cost(MyVar(id=jnp.arange(n)), target)]
    problem = jaxls.LeastSquaresProblem(costs).analyze()
    new_vals = problem.solve(vals).vals
    return new_vals, new_vals[MyVar(id=jnp.arange(n))]

# Solve for each target in sequence, warm-starting from previous solution.
final_vals, trajectory = jax.lax.scan(solve_step, initial_vals, targets_sequence)

Sweeping solver parameters#

Because solver parameters are traced, they can be swept over in parallel with jax.vmap():

def solve_with_damping(damping: jax.Array) -> jaxls.VarValues:
    # The damping parameter is traced, so the whole sweep shares one
    # compiled solve.
    return problem.solve(
        initial_vals,
        trust_region=jaxls.TrustRegionConfig(lambda_initial=damping),
    ).vals

# Sweep over damping values.
solutions = jax.jit(jax.vmap(solve_with_damping))(jnp.logspace(-3, 3, 10))