What’s traced?#
JAX’s JIT compilation requires distinguishing between traced values (that can change between calls) and static values (that trigger recompilation if changed).
This page discusses:
What’s traced vs static in jaxls
Why variable IDs are traced
What triggers recompilation
Examples: leveraging traced values
Background: JAX tracing#
When JAX JIT-compiles a function, it traces the computation with abstract values. Values can be:
Traced: Represented as abstract shapes/dtypes during compilation. The actual values are substituted at runtime. Changing traced values doesn’t trigger recompilation.
Static: Baked into the compiled code. Changing static values triggers a new compilation.
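As a rough illustration of this distinction in plain JAX (not jaxls code), the sketch below counts how often a jitted function is retraced. The names `scaled_sum` and `trace_log` are made up for this example. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so the list grows only when a new compilation is triggered.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # Appended to only while JAX traces the function.


@partial(jax.jit, static_argnames="scale")
def scaled_sum(x: jax.Array, scale: float) -> jax.Array:
    trace_log.append((x.shape, scale))
    return jnp.sum(x) * scale


scaled_sum(jnp.ones(3), scale=2.0)   # First call: traces and compiles.
scaled_sum(jnp.zeros(3), scale=2.0)  # New traced values, same shape: cache hit.
scaled_sum(jnp.ones(3), scale=3.0)   # New static value: retraces.
scaled_sum(jnp.ones(4), scale=3.0)   # New array shape: retraces.

print(len(trace_log))
```

Only the first, third, and fourth calls add an entry to `trace_log`: changing the contents of `x` is free, while changing the static `scale` or the shape of `x` is not.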
jaxls’s static/traced split#
Traced (can change without recompilation)#
| Value | Why traced |
|---|---|
| Variable IDs ( | Different variable subsets use same code |
| Variable values ( | Values change during optimization |
| Jacobian values | Computed from current variable values |
| Solver parameters (LM damping, CG tolerance) | May adapt during solve |
| Augmented Lagrangian multipliers/penalties | Updated between outer iterations |
Static (changes trigger recompilation)#
| Value | Why static |
|---|---|
| Problem dimensions ( | Determines array shapes |
| Cost counts and structure ( | Determines vectorization structure |
| Solver choice ( | Different code paths |
| Variable tangent dimensions | Determines Jacobian block sizes |
| Constraint types (equality vs inequality) | Different augmented Lagrangian update logic |
Why variable IDs are traced#
Var.id is a traced jax.Array, not a static int. This enables automatic
vectorization of costs. These two approaches produce equivalent results after analysis:
# Approach A: Array of IDs (explicitly batched).
costs_a = [
    pairwise_cost(
        MyVar(id=jnp.arange(100)),  # IDs 0-99
        MyVar(id=jnp.arange(100) + 1),  # IDs 1-100
        data,
    )
]

# Approach B: List comprehension (implicitly batched).
costs_b = [
    pairwise_cost(MyVar(id=i), MyVar(id=i + 1), data[i])
    for i in range(100)
]
In approach A, the IDs are explicit arrays. In approach B, jaxls stacks the
scalar IDs into arrays during analyze(). In both cases, the analyzed problem
vmaps over the ID arrays to compute residuals and Jacobians in parallel.
This also works with inline lambda functions:
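# Approach C: Inline lambdas with list comprehension.
# target[i] is passed as an argument: a lambda that closed over i would
# only see its final value by the time the cost is evaluated.
costs_c = [
    jaxls.Cost.factory(
        lambda vals, v, t: vals[v] - t
    )(MyVar(id=i), target[i])
    for i in range(100)
]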
Even though this creates 100 separate lambda objects, jaxls recognizes them as
the same cost type using bytecode analysis: dis.Bytecode extracts the instruction
sequence, which is identical across all 100 lambdas.
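The idea of matching functions by their instruction sequence can be sketched with the standard library alone. This is not jaxls's actual implementation; `instruction_signature` is a hypothetical helper, and comparing `(opname, argrepr)` pairs is just one plausible way to define "same bytecode".

```python
import dis


def instruction_signature(fn):
    """Return the (opname, argrepr) sequence for a function's bytecode."""
    return tuple((ins.opname, ins.argrepr) for ins in dis.Bytecode(fn))


# Separate lambda objects created in a loop, as in Approach C.
lambdas = [lambda vals, v: vals[v] * 2.0 for _ in range(100)]
other = lambda vals, v: vals[v] + 2.0  # Different operation.

# All 100 lambdas collapse to a single signature.
signatures = {instruction_signature(fn) for fn in lambdas}
print(len(signatures))

# A lambda with a different body yields a different signature.
print(instruction_signature(other) in signatures)
```

The 100 lambdas are distinct objects but share one signature, so they can be grouped and vectorized together, while `other` would land in a separate group.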
Examples: leveraging traced values#
Different variable assignments#
The same cost structure can connect different variables without recompilation.
jax.jit() compiles only on the first call, provided later calls pass arrays
with the same shapes and dtypes:
@jax.jit
def solve_with_ids(
    var_ids: jax.Array,  # Which variables to use.
    targets: jax.Array,
) -> jaxls.VarValues:
    # Same cost type, but connecting different variable IDs.
    costs = [prior_cost(MyVar(id=var_ids), targets)]
    problem = jaxls.LeastSquaresProblem(costs).analyze()
    return problem.solve(initial_vals).vals

# First call compiles; subsequent calls reuse compiled code.
solution1 = solve_with_ids(jnp.array([0, 1, 2]), targets_a)
solution2 = solve_with_ids(jnp.array([3, 4, 5]), targets_b)  # No recompilation.
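The underlying mechanism can be seen without jaxls. In the minimal sketch below, `prior_residuals` and `id_trace_log` are illustrative names, and the "variables" are simply entries of a flat array. Because the indices are traced, the compiled gather works for any IDs of the same shape; only a change in shape forces a retrace.

```python
import jax
import jax.numpy as jnp

id_trace_log = []  # Appended to only while JAX traces the function.


@jax.jit
def prior_residuals(var_ids, all_values, targets):
    id_trace_log.append(var_ids.shape)
    # Gather with traced indices: the compiled code does not depend on
    # which IDs are requested, only on how many.
    return all_values[var_ids] - targets


all_values = jnp.arange(10.0)
res_a = prior_residuals(jnp.array([0, 1, 2]), all_values, jnp.zeros(3))
res_b = prior_residuals(jnp.array([3, 4, 5]), all_values, jnp.zeros(3))  # Cache hit.
res_c = prior_residuals(jnp.array([0, 1, 2, 3]), all_values, jnp.zeros(4))  # Retraces.

print(len(id_trace_log))
```

The second call reuses the first compilation even though it selects different entries; the third call retraces because the ID array has a new shape.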
Batched solves with vmap#
With jax.vmap(), solve multiple problem instances in parallel by vmapping
over initial values:
def solve_one(init_vals: jaxls.VarValues) -> jaxls.VarValues:
    return problem.solve(init_vals).vals

# Solve 100 problems in parallel.
batched_solutions = jax.jit(jax.vmap(solve_one))(batched_init_vals)
Sequential solves with scan#
With jax.lax.scan(), solve a sequence of problems where each uses the
previous solution as its initial guess:
def solve_step(vals: jaxls.VarValues, target: jax.Array) -> tuple[jaxls.VarValues, jax.Array]:
    costs = [prior_cost(MyVar(id=jnp.arange(n)), target)]
    problem = jaxls.LeastSquaresProblem(costs).analyze()
    new_vals = problem.solve(vals).vals
    return new_vals, new_vals[MyVar(id=jnp.arange(n))]

# Solve for each target in sequence, warm-starting from previous solution.
final_vals, trajectory = jax.lax.scan(solve_step, initial_vals, targets_sequence)
Sweeping solver parameters#
Traced values can be swept over in parallel with jax.vmap():
def solve_with_damping(damping: jax.Array) -> jaxls.VarValues:
    return problem.solve(
        initial_vals,
        trust_region=jaxls.TrustRegionConfig(lambda_initial=damping),
    ).vals

# Sweep over damping values.
solutions = jax.jit(jax.vmap(solve_with_damping))(jnp.logspace(-3, 3, 10))
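To show why a traced damping value can be vmapped, here is a toy stand-in that does not use jaxls: a single Levenberg-Marquardt-style step on the scalar residual r(x) = a * x - b. The function `damped_step` and the constants `a` and `b` are assumptions made for this sketch, not part of the library.

```python
import jax
import jax.numpy as jnp

a, b = 2.0, 4.0  # Toy residual r(x) = a * x - b, minimized at x = b / a.


def damped_step(damping: jax.Array) -> jax.Array:
    """One damped Gauss-Newton step from x = 0 for the toy residual."""
    x = 0.0
    residual = a * x - b
    # Jacobian is a, so the step is -(J^T r) / (J^T J + damping).
    return x - (a * residual) / (a * a + damping)


dampings = jnp.logspace(-3, 3, 10)
steps = jax.jit(jax.vmap(damped_step))(dampings)  # One compile, ten dampings.
errors = jnp.abs(steps - b / a)

print(steps.shape)
```

Since the damping enters the computation only as a number, all ten values share one compiled function, and the remaining error after one step grows with the damping as expected.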