Typed API
jaxls aims to provide a typed API for nonlinear least squares. Type-safe APIs are easier to use: IDE autocomplete works, type checkers catch errors before runtime, and the code is self-documenting. You'll also usually get better results when writing code with the help of an LLM, since it can inspect the typed API or run a type checker automatically.
This page discusses:

- How jaxls's decorator-based approach preserves function signatures
- How generic variables (Var[T]) preserve value types
- Python 3.10 compatibility
Type-safe decorators
jaxls’s cost API is designed around two recent Python type system constructs:
variadic generics (PEP 646) and parameter
specification variables (PEP 612). This
enables decorator-based cost definition that’s fully typed.
Before and after the decorator
The @jaxls.Cost.factory decorator transforms a residual function into a
cost factory:
```python
import jax
import jaxlie
import jaxls

# Before: residual function
# Signature: (VarValues, SE3Var, jax.Array) -> jax.Array
def pose_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jax.Array,
) -> jax.Array:
    return (vals[var].inverse() @ jaxlie.SE3(target)).log()

# After: cost factory (decorated)
# Signature: (SE3Var, jax.Array) -> Cost
@jaxls.Cost.factory
def pose_cost(
    vals: jaxls.VarValues,
    var: jaxls.SE3Var,
    target: jax.Array,
) -> jax.Array:
    return (vals[var].inverse() @ jaxlie.SE3(target)).log()
```
After decoration, calling pose_cost(my_var, my_target) returns a Cost object.
The VarValues argument is not passed at this point; it is supplied later, during optimization.
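To make the deferred evaluation concrete, here is a minimal pure-Python sketch of what a factory decorator of this kind could do at runtime. The names SketchCost, sketch_factory, and offset_cost are hypothetical, a plain dict stands in for VarValues, and this is not jaxls's actual implementation: the decorator only records the non-VarValues arguments, and the residual is computed once values are available.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple


@dataclass(frozen=True)
class SketchCost:
    """Hypothetical stand-in for a Cost: a residual plus its captured arguments."""

    compute_residual: Callable[..., Any]
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def evaluate(self, vals: Dict[str, float]) -> Any:
        # Variable values arrive only now, e.g. from a solver's current iterate.
        return self.compute_residual(vals, *self.args, **self.kwargs)


def sketch_factory(compute_residual: Callable[..., Any]) -> Callable[..., SketchCost]:
    """Turn residual(vals, *args, **kwargs) into factory(*args, **kwargs) -> SketchCost."""

    def make_cost(*args: Any, **kwargs: Any) -> SketchCost:
        # No values are needed to build a cost; just record the arguments.
        return SketchCost(compute_residual, args, kwargs)

    return make_cost


@sketch_factory
def offset_cost(vals: Dict[str, float], var: str, target: float) -> float:
    # Scalar residual: how far the variable's value is from a target.
    return vals[var] - target


cost = offset_cost("x", 2.0)      # Builds a cost without any values.
print(cost.evaluate({"x": 5.0}))  # 3.0
```

Keeping construction and evaluation separate is what lets a solver build all costs up front and then re-evaluate them at each iterate.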
Implementation: parameter specification variables
The key type aliases behind the syntax above are:
```python
# A residual function takes VarValues + arbitrary args, returns an array.
type ResidualFunc[**Args] = Callable[
    Concatenate[VarValues, Args],
    jax.Array,
]

# A cost factory takes arbitrary args (no VarValues), returns a Cost.
type CostFactory[**Args] = Callable[
    Args,
    Cost,
]
```
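The **Args syntax declares a parameter specification variable (a ParamSpec, PEP 612), which captures the entire parameter list after VarValues, including parameter names, kinds, and types.
Concatenate[VarValues, Args] prepends VarValues to form the full signature.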
The decorator’s type signature expresses the transformation:
```python
@staticmethod
def factory[**Args_](
    compute_residual: ResidualFunc[Args_],
) -> CostFactory[Args_]:
    ...
```
This says: given a residual function with signature (VarValues, *Args) -> Array,
return a factory with signature (*Args) -> Cost.
Type-safe variables
Custom variable types are defined by subclassing jaxls.Var[T] and passing keyword arguments in the class definition, which are consumed by __init_subclass__:
```python
import jax
import jax.numpy as jnp
import jaxlie
import jaxls

# Simple Euclidean variable (tangent_dim inferred from default shape)
class PointVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.zeros(3),
):
    """A 3D point variable."""

# Manifold variable (explicit tangent_dim and retraction)
class SO3Var(
    jaxls.Var[jaxlie.SO3],
    default_factory=jaxlie.SO3.identity,
    retract_fn=jaxlie.manifold.rplus,
    tangent_dim=3,
):
    """An SO(3) rotation variable."""
```
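A minimal sketch of how class-definition keywords can configure a generic variable type through __init_subclass__. SketchVar, PointVar, and AngleVar are hypothetical; tuples stand in for arrays so the example runs without JAX, and inferring the tangent dimension with len() is an assumption that loosely mirrors the "inferred from default shape" comment above. jaxls's real Var may differ in every detail.

```python
import math
from typing import Any, Callable, ClassVar, Generic, Optional, Tuple, TypeVar

T = TypeVar("T")
Retract = Callable[[Any, Tuple[float, ...]], Any]


def _euclidean_retract(x: Tuple[float, ...], delta: Tuple[float, ...]) -> Tuple[float, ...]:
    # Euclidean retraction: elementwise addition of the tangent update.
    return tuple(a + b for a, b in zip(x, delta))


class SketchVar(Generic[T]):
    """Hypothetical stand-in for jaxls.Var[T], configured by class keywords."""

    default_factory: ClassVar[Callable[[], Any]]
    retract_fn: ClassVar[Retract]
    tangent_dim: ClassVar[int]

    def __init_subclass__(
        cls,
        *,
        default_factory: Callable[[], Any],
        retract_fn: Optional[Retract] = None,
        tangent_dim: Optional[int] = None,
        **kwargs: Any,
    ) -> None:
        super().__init_subclass__(**kwargs)
        if tangent_dim is None:
            # Assumption: Euclidean defaults are sequences, so len() gives the dim.
            tangent_dim = len(default_factory())
        # staticmethod stops the stored callables from binding to instances.
        cls.default_factory = staticmethod(default_factory)
        cls.retract_fn = staticmethod(retract_fn or _euclidean_retract)
        cls.tangent_dim = tangent_dim

    def __init__(self, id: int) -> None:
        self.id = id


class PointVar(SketchVar[Tuple[float, float, float]], default_factory=lambda: (0.0, 0.0, 0.0)):
    """A 3D point; tangent_dim is inferred as 3."""


class AngleVar(
    SketchVar[float],
    default_factory=lambda: 0.0,
    retract_fn=lambda x, delta: math.remainder(x + delta[0], math.tau),
    tangent_dim=1,
):
    """A planar angle kept in [-pi, pi] by its retraction."""
```

Because the configuration lives in the class statement, each variable type is a distinct class that a type checker can reason about, and a missing required keyword fails as soon as the class is defined.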
The generic Var[T] class preserves value types throughout the API:
```python
# SE3Var subclasses Var[jaxlie.SE3].
var = jaxls.SE3Var(id=0)

# vals[var] returns jaxlie.SE3, not a generic array.
pose = vals[var]  # Type checker knows this is SE3.

# IDE autocomplete works for SE3 methods.
position = pose.translation()  # Autocomplete suggests .translation(), .rotation(), etc.
```
This is implemented via VarValues.__getitem__:
```python
def __getitem__[T](self, var: Var[T]) -> T:
    """Get value for a variable, preserving type."""
    ...
```
The generic parameter T flows from the variable definition through to the return type.
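The same pattern can be sketched with pre-3.12 TypeVar syntax using a hypothetical SketchVar/SketchVarValues pair (not jaxls's actual VarValues). Each variable class fixes T, and __getitem__ returns T, so a type checker such as mypy or pyright can infer the type of a lookup from the variable alone.

```python
from typing import Any, Dict, Generic, Tuple, TypeVar

T = TypeVar("T")


class SketchVar(Generic[T]):
    """Hypothetical typed variable handle; subclasses fix T and a default."""

    default: Any = None

    def __init__(self, id: int) -> None:
        self.id = id


class SketchVarValues:
    """Hypothetical value store whose lookups return the variable's type T."""

    def __init__(self) -> None:
        # Keyed by (variable class, id) so NameVar(0) and CountVar(0) differ.
        self._vals: Dict[Tuple[type, int], Any] = {}

    def __setitem__(self, var: SketchVar[T], value: T) -> None:
        self._vals[(type(var), var.id)] = value

    def __getitem__(self, var: SketchVar[T]) -> T:
        # T is solved from the argument's static type, so a checker infers
        # str for a SketchVar[str]. At runtime this is a dict lookup that
        # falls back to the class-level default.
        return self._vals.get((type(var), var.id), type(var).default)


class NameVar(SketchVar[str]):
    default = "unnamed"


class CountVar(SketchVar[int]):
    default = 0


vals = SketchVarValues()
vals[NameVar(0)] = "base"
name = vals[NameVar(0)]    # Checker infers str.
count = vals[CountVar(0)]  # Checker infers int; falls back to the default.
```

The runtime container is untyped (values are stored as Any); the precision comes entirely from the method signatures, which is why no casts are needed at call sites.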
Python 3.10 compatibility
The full typing features used by jaxls require Python 3.12+. For Python 3.10/3.11 compatibility, jaxls includes a transpiler (transpile_py310.py) that:

- Removes PEP 695 syntax (type aliases, [T] on classes)
- Strips type annotations that would fail at runtime
- Generates compatible code in src/jaxls/_py310/
The 3.10 version works correctly at runtime but loses some static type information.
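One step of such a transpiler can be sketched with the standard ast module. The pass below is hypothetical and is not the contents of transpile_py310.py: it removes parameter and return annotations from function definitions and, when run on Python 3.12+ (the only versions that can parse PEP 695 source), also clears `def f[T](...)` type parameters, then regenerates source with ast.unparse.

```python
import ast


class StripSignatureAnnotations(ast.NodeTransformer):
    """Hypothetical pass: drop annotations and PEP 695 type params from defs."""

    def visit_FunctionDef(self, node):
        args = node.args
        for arg in args.posonlyargs + args.args + args.kwonlyargs:
            arg.annotation = None
        for arg in (args.vararg, args.kwarg):
            if arg is not None:
                arg.annotation = None
        node.returns = None
        # Only present on Python 3.12+ ASTs; clearing it removes `[T]` from defs.
        if getattr(node, "type_params", None):
            node.type_params = []
        self.generic_visit(node)  # Also handle nested functions.
        return node

    visit_AsyncFunctionDef = visit_FunctionDef


def strip_signature_annotations(source: str) -> str:
    tree = StripSignatureAnnotations().visit(ast.parse(source))
    return ast.unparse(tree)


src = '''
def scale(x: float, *args: int, k: "int" = 2, **kw: str) -> float:
    return x * k
'''
print(strip_signature_annotations(src))
```

Stripping signature annotations is safe for plain functions, but annotations on class attributes (for example, dataclass fields) carry runtime meaning, so a real pass would need to handle those more selectively.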