Typed API

jaxls aims to provide a typed API for nonlinear least squares. Type-safe APIs are easier to use: IDE autocomplete works, type checkers catch errors before runtime, and the code is self-documenting. The same properties help when you write code with an LLM assistant, which can inspect the typed API or run a type checker on the code it generates.

This page discusses:

  • How jaxls’s decorator-based approach preserves function signatures

  • How generic variables (Var[T]) preserve value types

  • Python 3.10 compatibility

Type-safe decorators

jaxls’s cost API is designed around two recent Python type system constructs: variadic generics (PEP 646) and parameter specification variables (PEP 612). This enables decorator-based cost definition that’s fully typed.

Before and after the decorator

The @jaxls.Cost.factory decorator transforms a residual function into a cost factory:

import jax
import jaxlie
import jaxls

# Before: residual function
# Signature: (VarValues, SE3Var, jax.Array) -> jax.Array
def pose_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jax.Array,
) -> jax.Array:
    return (vals[var].inverse() @ jaxlie.SE3(target)).log()

# After: cost factory (decorated)
# Signature: (SE3Var, jax.Array) -> Cost
@jaxls.Cost.factory
def pose_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jax.Array,
) -> jax.Array:
    return (vals[var].inverse() @ jaxlie.SE3(target)).log()

After decoration, calling pose_cost(my_var, my_target) returns a Cost object instead of evaluating the residual. The VarValues argument is supplied later, during optimization.

Implementation: parameter specification variables

The important type aliases for implementing the syntax above are:

from typing import Callable, Concatenate

import jax

# VarValues and Cost are jaxls's own classes.

# A residual function takes VarValues + arbitrary args, returns an array.
type ResidualFunc[**Args] = Callable[
    Concatenate[VarValues, Args],
    jax.Array,
]

# A cost factory takes arbitrary args (no VarValues), returns a Cost.
type CostFactory[**Args] = Callable[
    Args,
    Cost,
]

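The **Args syntax declares a parameter specification variable (PEP 612). It captures the entire parameter list that follows VarValues: parameter names, kinds (positional or keyword), and types. Concatenate[VarValues, Args] prepends VarValues to that list to form the residual function's full signature.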

The decorator’s type signature expresses the transformation:

@staticmethod
def factory[**Args_](
    compute_residual: ResidualFunc[Args_],
) -> CostFactory[Args_]:
    ...

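This says: given a residual function with signature (VarValues, <Args>) -> jax.Array, return a factory with signature (<Args>) -> Cost, where <Args> is the same parameter list in both. Because the parameters carry over unchanged, calls to the factory are checked against the residual's own argument names and types.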

Type-safe variables

Custom variable types are defined by subclassing jaxls.Var[T] and passing class keyword arguments, which are received by __init_subclass__:

import jax
import jax.numpy as jnp
import jaxlie
import jaxls

# Simple Euclidean variable (tangent_dim inferred from default shape)
class PointVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.zeros(3),
):
    """A 3D point variable."""

# Manifold variable (explicit tangent_dim and retraction)
class SO3Var(
    jaxls.Var[jaxlie.SO3],
    default_factory=jaxlie.SO3.identity,
    retract_fn=jaxlie.manifold.rplus,
    tangent_dim=3,
):
    """An SO(3) rotation variable."""
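To see how class keyword arguments reach `__init_subclass__`, here is a simplified, hypothetical `Var` base class. It is not jaxls's implementation: it uses plain Python lists instead of JAX arrays, infers `tangent_dim` as the length of the default value when no retraction is given, and falls back to elementwise addition as the Euclidean retraction. Those rules are assumptions chosen to mirror the two definitions above; `AngleVar` is an illustrative manifold variable.

```python
from __future__ import annotations

import math
from typing import Any, Callable, ClassVar, Generic, TypeVar

T = TypeVar("T")


def _euclidean_retract(x: list[float], delta: list[float]) -> list[float]:
    # Default retraction for vector-valued variables: x + delta.
    return [xi + di for xi, di in zip(x, delta)]


class Var(Generic[T]):
    """Hypothetical, simplified variable base class (not jaxls.Var)."""

    default_factory: ClassVar[Callable[[], Any]]
    retract_fn: ClassVar[Callable[[Any, list[float]], Any]]
    tangent_dim: ClassVar[int]

    def __init_subclass__(
        cls,
        *,
        default_factory: Callable[[], Any],
        retract_fn: Callable[[Any, list[float]], Any] | None = None,
        tangent_dim: int | None = None,
        **kwargs: Any,
    ) -> None:
        # Keyword arguments in `class X(Var[T], default_factory=...)` land here.
        super().__init_subclass__(**kwargs)
        if retract_fn is None:
            # Euclidean case: tangent dimension = number of entries in the default.
            retract_fn = _euclidean_retract
            if tangent_dim is None:
                tangent_dim = len(default_factory())
        elif tangent_dim is None:
            raise TypeError(f"{cls.__name__}: tangent_dim is required with retract_fn")
        # staticmethod keeps the stored callables from binding to instances.
        cls.default_factory = staticmethod(default_factory)
        cls.retract_fn = staticmethod(retract_fn)
        cls.tangent_dim = tangent_dim

    def __init__(self, id: int) -> None:
        self.id = id


class PointVar(Var[list], default_factory=lambda: [0.0, 0.0, 0.0]):
    """A 3D point variable; tangent_dim is inferred as 3."""


class AngleVar(
    Var[float],
    default_factory=lambda: 0.0,
    # Manifold-style retraction: add the 1D step, then wrap to [-pi, pi].
    retract_fn=lambda theta, delta: math.remainder(theta + delta[0], 2 * math.pi),
    tangent_dim=1,
):
    """A planar angle variable with an explicit tangent_dim."""


print(PointVar.tangent_dim, AngleVar.tangent_dim)  # 3 1
```

Because the configuration is consumed at class-creation time, a missing `tangent_dim` for a manifold variable fails as soon as the class is defined, not later inside the solver.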

The generic Var[T] class preserves value types throughout the API:

# SE3Var subclasses Var[jaxlie.SE3].
var = jaxls.SE3Var(id=0)

# vals[var] returns jaxlie.SE3, not a generic array.
pose = vals[var]  # Type checker knows this is SE3.

# IDE autocomplete works for SE3 methods
position = pose.translation()  # Autocomplete suggests .translation(), .rotation(), etc.

This is implemented via VarValues.__getitem__:

def __getitem__[T](self, var: Var[T]) -> T:
    """Get value for a variable, preserving type."""
    ...

The generic parameter T flows from the variable definition through to the return type.
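The same pattern can be tried without jaxls. In this hypothetical sketch (not the library's code), `Var` is a frozen dataclass used as a typed dictionary key, and `VarValues.__getitem__` uses an explicit `TypeVar`, the pre-3.12 spelling of `[T]`. The lookup itself is an ordinary `dict` access; only the annotation links the key's `T` to the result type. `CountVar` and `LabelVar` are illustrative names.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Generic, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class Var(Generic[T]):
    """Hypothetical variable handle: a hashable key tagged with its value type T."""

    id: int


class CountVar(Var[int]):
    """Variable whose value is an int."""


class LabelVar(Var[str]):
    """Variable whose value is a str."""


class VarValues:
    """Hypothetical stand-in for jaxls.VarValues."""

    def __init__(self, values: dict[Var[Any], Any]) -> None:
        self._values = dict(values)

    def __getitem__(self, var: Var[T]) -> T:
        # Untyped at runtime; the signature tells the checker the result is T.
        return self._values[var]


count_var = CountVar(id=0)
label_var = LabelVar(id=0)  # Same id, different class: a distinct key.
vals = VarValues({count_var: 3, label_var: "pose"})

n = vals[count_var]  # Inferred as int.
s = vals[label_var]  # Inferred as str, so .upper() autocompletes.
print(n + 1, s.upper())  # 4 POSE
```

In this sketch the dataclass equality compares classes as well as ids, so two variables with the same id but different value types never collide as keys.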

Python 3.10 compatibility

The full typing features used by jaxls require Python 3.12+. For Python 3.10/3.11 compatibility, jaxls includes a transpiler (transpile_py310.py) that:

  1. Removes PEP 695 syntax (type alias statements and [T]-style type parameter lists, such as those on classes and functions)

  2. Strips type annotations that would fail at runtime

  3. Generates compatible code in src/jaxls/_py310/

The 3.10 version works correctly at runtime but loses some static type information.