Core API

Problem

class jaxls.LeastSquaresProblem

A least squares problem is represented as a bipartite graph with two types of nodes:

  • Cost: cost terms or constraints.

    • kind="l2_squared" (default): Minimize squared L2 norm ||r(x)||^2

    • kind="constraint_eq_zero": Equality constraint r(x) = 0

    • kind="constraint_leq_zero": Inequality constraint r(x) <= 0

    • kind="constraint_geq_zero": Inequality constraint r(x) >= 0

  • Var: the parameters we want to optimize.
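To make the graph structure concrete, the following sketch uses plain NumPy rather than jaxls: each hypothetical cost lists the variables it connects to, and every cost-variable edge marks one dense block of the stacked Jacobian, while all other blocks are structurally zero. The names `var_dims`, `costs`, and `jacobian_block_mask` are illustrative only.

```python
import numpy as np

# Hypothetical toy problem: three variables, each with a tangent dimension of 3.
var_dims = {"x0": 3, "x1": 3, "x2": 3}

# Each cost node lists (residual dimension, variables it connects to).
# These are the edges of the bipartite cost/variable graph.
costs = [
    (3, ["x0"]),        # prior on x0
    (3, ["x0", "x1"]),  # relative cost between x0 and x1
    (3, ["x1", "x2"]),  # relative cost between x1 and x2
]


def jacobian_block_mask(costs, var_dims):
    """Return a boolean (residual_dim, tangent_dim) mask of structurally nonzero entries."""
    # Column offset of each variable in the stacked tangent vector.
    offsets, col = {}, 0
    for name, dim in var_dims.items():
        offsets[name] = col
        col += dim
    rows = sum(res_dim for res_dim, _ in costs)
    mask = np.zeros((rows, col), dtype=bool)
    row = 0
    for res_dim, var_names in costs:
        for name in var_names:
            start = offsets[name]
            # Each cost-variable edge fills one dense block.
            mask[row:row + res_dim, start:start + var_dims[name]] = True
        row += res_dim
    return mask


mask = jacobian_block_mask(costs, var_dims)
print(mask.shape)       # (9, 9)
print(int(mask.sum()))  # 45 structurally nonzero entries out of 81
```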

costs: Iterable[Cost]
variables: Iterable[Var]
show(*, width: int = 800, height: int = 500, max_costs: int = 1000, max_variables: int = 500) -> None

Display an interactive graph showing costs and variables.

In Jupyter, JupyterLab, and VS Code notebooks, the graph is displayed inline. Otherwise, it opens in the default web browser.

Parameters:
  • width (int) – Maximum width of the visualization in pixels.

  • height (int) – Height of the visualization in pixels.

  • max_costs (int) – Maximum number of cost nodes to show. When multiple cost types exist, the limit is distributed proportionally across types.

  • max_variables (int) – Maximum number of variables per type to show. Only costs whose variables are all visible are shown.

Return type:

None

analyze(use_onp: bool = False) -> AnalyzedLeastSquaresProblem

Analyze the sparsity pattern of the least squares problem. This must be done before solving.

Processes all costs and variables to compute the sparse Jacobian structure, group costs by structure for vectorization, and prepare for optimization.

Parameters:

use_onp (bool) – If True, use numpy instead of jax.numpy for index computations. This can make problem setup faster on CPU.

Returns:

An AnalyzedLeastSquaresProblem ready for solving.

Return type:

AnalyzedLeastSquaresProblem
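The grouping step described above can be pictured with a small sketch: costs that share a residual function and argument shapes are stacked along a new leading axis, so each group is evaluated in one vectorized call. This is a simplified NumPy illustration, not jaxls's implementation; the grouping key and the `group_and_stack` helper are hypothetical.

```python
from collections import defaultdict

import numpy as np


def prior_residual(value, target):
    # Works on a single cost or on a stacked batch (leading axis).
    return value - target


def between_residual(a, b):
    return b - a


# Hypothetical unstacked costs: (residual function, tuple of array arguments).
costs = [
    (prior_residual, (np.array([1.0, 2.0]), np.zeros(2))),
    (between_residual, (np.zeros(2), np.ones(2))),
    (prior_residual, (np.array([3.0, 4.0]), np.zeros(2))),
    (between_residual, (np.ones(2), np.full(2, 2.0))),
]


def group_and_stack(costs):
    """Group costs by (function, argument shapes); stack each group's arguments."""
    groups = defaultdict(list)
    for fn, args in costs:
        key = (fn, tuple(np.shape(a) for a in args))
        groups[key].append(args)
    # For each group, stack the i-th argument of every cost along a new axis 0.
    return [
        (fn, tuple(np.stack(column) for column in zip(*arg_list)))
        for (fn, _shapes), arg_list in groups.items()
    ]


for fn, stacked_args in group_and_stack(costs):
    # One vectorized call per group instead of one call per cost.
    print(fn.__name__, fn(*stacked_args).shape)  # e.g. prior_residual (2, 2)
```

In JAX, the per-group evaluation would typically use `jax.vmap`; here the toy residual functions simply broadcast over the leading axis.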

class jaxls.AnalyzedLeastSquaresProblem

AnalyzedLeastSquaresProblem(_stacked_costs: ‘tuple[_AnalyzedCost, …]’, _cost_counts: ‘jdc.Static[tuple[int, …]]’, _sorted_ids_from_var_type: ‘dict[type[Var], jax.Array]’, _jac_coords_coo: ‘SparseCooCoordinates’, _jac_coords_csr: ‘SparseCsrCoordinates’, _tangent_ordering: ‘jdc.Static[VarTypeOrdering]’, _tangent_start_from_var_type: ‘jdc.Static[dict[type[Var[Any]], int]]’, _tangent_dim: ‘jdc.Static[int]’, _residual_dim: ‘jdc.Static[int]’)

solve(initial_vals: VarValues | None = None, *, linear_solver: Literal['conjugate_gradient', 'cholmod', 'dense_cholesky'] | ConjugateGradientConfig = 'conjugate_gradient', trust_region: TrustRegionConfig | None = TrustRegionConfig(), termination: TerminationConfig = TerminationConfig(), sparse_mode: Literal['blockrow', 'coo', 'csr'] = 'blockrow', verbose: bool = True, augmented_lagrangian: AugmentedLagrangianConfig | None = None, return_summary: Literal[False] = False) -> VarValues
solve(initial_vals: VarValues | None = None, *, linear_solver: Literal['conjugate_gradient', 'cholmod', 'dense_cholesky'] | ConjugateGradientConfig = 'conjugate_gradient', trust_region: TrustRegionConfig | None = TrustRegionConfig(), termination: TerminationConfig = TerminationConfig(), sparse_mode: Literal['blockrow', 'coo', 'csr'] = 'blockrow', verbose: bool = True, augmented_lagrangian: AugmentedLagrangianConfig | None = None, return_summary: Literal[True]) -> tuple[VarValues, SolveSummary]

Solve the nonlinear least squares problem using either Gauss-Newton or Levenberg-Marquardt.

For constrained problems (any cost of kind "constraint_eq_zero", "constraint_leq_zero", or "constraint_geq_zero"), the Augmented Lagrangian method is used automatically.

Parameters:
  • initial_vals – Initial values for the variables. If None, each variable is initialized to the default value of its type.

  • linear_solver – The linear solver to use.

  • trust_region – Configuration for the Levenberg-Marquardt trust region. If None, Gauss-Newton steps are taken instead.

  • termination – Configuration for termination criteria.

  • sparse_mode – The representation to use for sparse matrix multiplication: "blockrow", "coo", or "csr".

  • verbose – Whether to print verbose output during optimization.

  • augmented_lagrangian – Configuration for Augmented Lagrangian method. Only used if constraints are present. If None and constraints exist, a default configuration will be used.

  • return_summary – If True, return a SolveSummary alongside the optimized values.

Returns:

Optimized variable values, or a (VarValues, SolveSummary) tuple if return_summary is True.
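The two step types differ only in how each linear system is built. The NumPy sketch below fits y = a·exp(b·t) with Levenberg-Marquardt: each step solves (J^T J + λI) δ = -J^T r, and setting λ = 0 would give a plain Gauss-Newton step. It is a textbook illustration with an arbitrary damping schedule, not jaxls's solver, which works with sparse Jacobians, the configured linear solver, and the termination criteria.

```python
import numpy as np

# Synthetic, noise-free data for y = a * exp(b * t) with a = 2.0, b = -1.5.
t = np.linspace(0.0, 2.0, 20)
y = 2.0 * np.exp(-1.5 * t)


def residual(p):
    a, b = p
    return a * np.exp(b * t) - y


def jacobian(p):
    a, b = p
    e = np.exp(b * t)
    # Columns: dr/da and dr/db.
    return np.stack([e, a * t * e], axis=1)


def levenberg_marquardt(p, lam=1e-2, iters=50):
    cost = np.sum(residual(p) ** 2)
    for _ in range(iters):
        r, J = residual(p), jacobian(p)
        # Damped normal equations; lam = 0 would be a Gauss-Newton step.
        A = J.T @ J + lam * np.eye(p.size)
        delta = np.linalg.solve(A, -J.T @ r)
        new_cost = np.sum(residual(p + delta) ** 2)
        if new_cost < cost:
            # Accept the step and move toward Gauss-Newton behavior.
            p, cost, lam = p + delta, new_cost, lam * 0.5
        else:
            # Reject the step and increase damping (shorter, more gradient-like steps).
            lam *= 10.0
    return p


p_opt = levenberg_marquardt(np.array([1.0, 0.0]))
print(p_opt)  # close to [2.0, -1.5]
```

Shrinking λ after accepted steps and growing it after rejected ones is one common schedule; the factors 0.5 and 10 are arbitrary choices for this sketch.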

compute_residual_vector(vals: VarValues) -> Array

Compute the residual vector. The optimized cost is the sum of the squared entries of this vector.

Parameters:

vals (VarValues)

Return type:

Array
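As a worked reading of this definition, the per-cost residuals are flattened into one vector r and the cost is r·r = ||r||^2. The residual values below are hypothetical.

```python
import numpy as np

# Hypothetical residuals from two costs (a 2-D and a 1-D residual).
r_prior = np.array([0.5, -1.0])
r_between = np.array([2.0])

# Concatenate into a single flat residual vector.
residual_vector = np.concatenate([r_prior, r_between])

# The optimized cost is the sum of squared entries, i.e. ||r||^2.
cost = float(residual_vector @ residual_vector)
print(cost)  # 0.25 + 1.0 + 4.0 = 5.25
```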

compute_constraint_values(vals: VarValues) -> Array

Compute all constraint values as a flat array.

For equality constraints, these should all be zero at a feasible solution. For inequality constraints (g(x) <= 0), these should be <= 0.

Parameters:

vals (VarValues)

Return type:

Array
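A minimal sketch of how such values could be checked for feasibility, with a small tolerance because iterative solvers satisfy constraints only approximately. The `is_feasible` helper and its tolerance are hypothetical; a constraint r(x) >= 0 is checked as the equivalent -r(x) <= 0.

```python
import numpy as np


def is_feasible(kind: str, values: np.ndarray, tol: float = 1e-6) -> bool:
    """Check constraint values for one constraint kind (hypothetical helper)."""
    values = np.asarray(values)
    if kind == "constraint_eq_zero":
        return bool(np.all(np.abs(values) <= tol))
    if kind == "constraint_leq_zero":
        return bool(np.all(values <= tol))
    if kind == "constraint_geq_zero":
        # r(x) >= 0 is equivalent to -r(x) <= 0.
        return bool(np.all(-values <= tol))
    raise ValueError(f"not a constraint kind: {kind}")


print(is_feasible("constraint_eq_zero", np.array([1e-9, -2e-8])))  # True
print(is_feasible("constraint_leq_zero", np.array([-0.3, 0.01])))  # False
print(is_feasible("constraint_geq_zero", np.array([0.0, 4.2])))    # True
```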

make_covariance_estimator(vals: VarValues, method: Literal['cholmod_spinv'] | LinearSolverCovarianceEstimatorConfig | None = None, *, scale_by_residual_variance: bool = False) -> CovarianceEstimator

Create a covariance estimator for uncertainty quantification.

This computes blocks of the covariance matrix (J^T J)^{-1}, which represents the uncertainty of estimated variables at the solution. The covariance is computed in the tangent space, appropriate for manifold variables like SE3, SO3, etc.

Parameters:
  • vals (VarValues) – Variable values at which to compute covariance (typically the solution from solve()).

  • method (Literal['cholmod_spinv'] | LinearSolverCovarianceEstimatorConfig | None) –

    Covariance computation method. Options:

    • None (default): Use conjugate gradient (CG) with block-Jacobi preconditioning. GPU-friendly and adapts to problem structure.

    • LinearSolverCovarianceEstimatorConfig: Custom linear solver config.

    • "cholmod_spinv": Use CHOLMOD's sparse inverse. Fast extraction, but requires sksparse and only includes entries in the sparsity pattern.

  • scale_by_residual_variance (bool) – If True, scale by the estimated residual variance sigma^2 = ||r||^2 / (m - n), where m is the number of residuals and n is the tangent dimension.

Returns:

A CovarianceEstimator that can compute covariance blocks via estimator.covariance(var0, var1).

Return type:

CovarianceEstimator
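For a small Euclidean problem, the quantity being estimated reduces to the familiar normal-equations covariance. The NumPy sketch below forms (J^T J)^{-1} densely for a hypothetical quadratic fit, applies the optional scaling sigma^2 = ||r||^2 / (m - n), and slices out covariance blocks for two hypothetical variables. It illustrates the math only; the estimator described above computes requested blocks rather than a dense inverse.

```python
import numpy as np

# Hypothetical data: fit y = c0 + c1 * x + c2 * x^2.
# Treat (c0, c1) as "variable 0" (tangent slice 0:2) and c2 as "variable 1" (slice 2:3).
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([1.1, 1.9, 5.2, 9.8, 17.1, 26.0])

J = np.stack([np.ones_like(x), x, x ** 2], axis=1)  # m x n Jacobian
params, *_ = np.linalg.lstsq(J, y, rcond=None)
r = J @ params - y  # residuals at the solution

m, n = J.shape
cov = np.linalg.inv(J.T @ J)  # unscaled covariance (J^T J)^{-1}

# Optional scaling by the residual variance sigma^2 = ||r||^2 / (m - n).
sigma2 = float(r @ r) / (m - n)
cov_scaled = sigma2 * cov

# Diagonal block for variable 0 and cross-covariance block between variables 0 and 1.
block_00 = cov_scaled[0:2, 0:2]
block_01 = cov_scaled[0:2, 2:3]
print(block_00.shape, block_01.shape)  # (2, 2) (2, 1)
```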

class jaxls.SolveSummary

SolveSummary(iterations: ‘jax.Array’, termination_criteria: ‘jax.Array’, termination_deltas: ‘jax.Array’, cost_history: ‘jax.Array’, lambda_history: ‘jax.Array’)

iterations: Array
termination_criteria: Array
termination_deltas: Array
cost_history: Array

History of non-augmented costs (l2_squared terms only, excludes constraint penalties).

lambda_history: Array

Cost

class jaxls.Cost

A cost or constraint term in our optimization problem.

The kind field determines how the residual function is interpreted:

  • "l2_squared" (default): Minimize squared L2 norm: ||r(x)||^2

  • "constraint_eq_zero": Equality constraint: r(x) = 0

  • "constraint_leq_zero": Inequality constraint: r(x) <= 0

  • "constraint_geq_zero": Inequality constraint: r(x) >= 0

Use the Cost.factory() decorator to create costs from a residual function.

Each Cost.compute_residual must include at least one jaxls.Var(id) in its inputs, where id is a scalar integer. Variables can appear anywhere in the input structure, including nested within pytrees (lists, dicts, dataclasses, etc.).

To create a batch of costs, add a leading batch axis to the arguments in Cost.args:

  • The batch size (the length of the leading axis) must be the same for all arguments. Leading axes of shape (1,) are broadcast.

  • The id field of each jaxls.Var instance must have shape () (unbatched) or (batch_size,) (batched).
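The two batching rules can be expressed as a small validation routine. The NumPy sketch below is hypothetical and simplified (it treats the first axis of every non-id argument as the batch axis): ids must have shape () or (batch_size,), and other arguments must share the batch size or have a leading axis of size 1, which is broadcast.

```python
import numpy as np


def check_and_broadcast(var_ids, other_args):
    """Hypothetical helper: validate batch shapes and broadcast size-1 leading axes.

    var_ids: integer id arrays, each of shape () or (batch_size,).
    other_args: arrays whose first axis is treated as the batch axis.
    """
    leading = [a.shape[0] for a in other_args]
    leading += [ids.shape[0] for ids in var_ids if ids.ndim == 1]
    sizes = {s for s in leading if s != 1}
    if len(sizes) > 1:
        raise ValueError(f"inconsistent batch sizes: {sorted(sizes)}")
    batch_size = sizes.pop() if sizes else 1

    for ids in var_ids:
        if ids.shape not in ((), (batch_size,)):
            raise ValueError(f"id shape {ids.shape} is neither () nor ({batch_size},)")

    # Leading axes of size 1 are broadcast to the full batch size.
    broadcast = [np.broadcast_to(a, (batch_size,) + a.shape[1:]) for a in other_args]
    return batch_size, broadcast


ids = np.arange(4)          # four costs, each on a different variable
shared_id = np.array(7)     # shape (): one variable shared by all four costs
targets = np.zeros((4, 3))  # per-cost data
weight = np.ones((1, 3))    # leading axis of size 1, broadcast to all costs

batch_size, (targets_b, weight_b) = check_and_broadcast([ids, shared_id], [targets, weight])
print(batch_size, weight_b.shape)  # 4 (4, 3)
```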

static create_factory(compute_residual: ResidualFunc[Args_] | None = None, *, kind: CostKind = 'l2_squared', jac_mode: Literal['auto', 'forward', 'reverse'] = 'auto', jac_batch_size: int | None = None, jac_custom_fn: JacobianFunc[Args_] | None = None, jac_custom_with_cache_fn: JacobianFuncWithCache[Args_, Any] | None = None, name: str | None = None) -> Callable[[ResidualFunc[Args_]], CostFactory[Args_]] | Callable[[ResidualFuncWithJacCache[Args_, Any]], CostFactory[Args_]] | CostFactory[Args_]

Deprecated: Use Cost.factory instead.

Parameters:
  • compute_residual (ResidualFunc[Args_] | None)

  • kind (CostKind)

  • jac_mode (Literal['auto', 'forward', 'reverse'])

  • jac_batch_size (int | None)

  • jac_custom_fn (JacobianFunc[Args_] | None)

  • jac_custom_with_cache_fn (JacobianFuncWithCache[Args_, Any] | None)

  • name (str | None)

Return type:

Callable[[ResidualFunc[Args_]], CostFactory[Args_]] | Callable[[ResidualFuncWithJacCache[Args_, Any]], CostFactory[Args_]] | CostFactory[Args_]

static factory(compute_residual: ResidualFunc[Args_] | None = None, *, kind: CostKind = 'l2_squared', jac_mode: Literal['auto', 'forward', 'reverse'] = 'auto', jac_batch_size: int | None = None, jac_custom_fn: JacobianFunc[Args_] | None = None, jac_custom_with_cache_fn: JacobianFuncWithCache[Args_, Any] | None = None, name: str | None = None) -> Callable[[ResidualFunc[Args_]], CostFactory[Args_]] | Callable[[ResidualFuncWithJacCache[Args_, Any]], CostFactory[Args_]] | CostFactory[Args_]

Decorator for creating costs from a residual function.

The decorated function should take VarValues as its first argument and return a residual array. The resulting factory will have the same signature but without the VarValues argument.

Parameters:
  • kind (CostKind) – How to interpret the residual (default: "l2_squared").

  • jac_mode (Literal['auto', 'forward', 'reverse']) – Autodiff mode for Jacobians ("auto", "forward", or "reverse").

  • jac_batch_size (int | None) – Batch size for Jacobian computation. Set to 1 to reduce memory usage.

  • compute_residual (ResidualFunc[Args_] | None)

  • jac_custom_fn (JacobianFunc[Args_] | None)

  • jac_custom_with_cache_fn (JacobianFuncWithCache[Args_, Any] | None)

  • name (str | None)

Return type:

Callable[[ResidualFunc[Args_]], CostFactory[Args_]] | Callable[[ResidualFuncWithJacCache[Args_, Any]], CostFactory[Args_]] | CostFactory[Args_]
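To illustrate the signature change described for factory(), here is a minimal pure-Python re-creation of the decorator pattern: the residual function takes a values mapping first, and the decorator returns a factory that accepts only the remaining arguments and stores them for later evaluation. `MiniCost` and `mini_factory` are hypothetical stand-ins, not jaxls classes, and a plain dict stands in for VarValues.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class MiniCost:
    """Hypothetical stand-in for a cost: a residual function plus its bound args."""

    compute_residual: Callable[..., Any]
    args: tuple
    kind: str = "l2_squared"

    def residual(self, vals: dict) -> Any:
        # The values mapping is supplied only at evaluation time.
        return self.compute_residual(vals, *self.args)


def mini_factory(compute_residual=None, *, kind="l2_squared"):
    """Usable both as @mini_factory and as @mini_factory(kind=...)."""

    def wrap(fn):
        def make_cost(*args):
            return MiniCost(fn, args, kind)

        return make_cost

    # Bare decorator: the residual function is passed directly.
    if compute_residual is not None:
        return wrap(compute_residual)
    # Called with keyword options only: return the decorator itself.
    return wrap


@mini_factory
def prior(vals, var_name, target):
    return vals[var_name] - target


@mini_factory(kind="constraint_leq_zero")
def upper_bound(vals, var_name, limit):
    return vals[var_name] - limit


cost = prior("x", 3.0)  # no values argument at construction time
print(cost.residual({"x": 5.0}))   # 2.0
print(upper_bound("x", 4.0).kind)  # constraint_leq_zero
```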

jac_batch_size: jdc.Static[int | None] = None
jac_custom_fn: jdc.Static[Callable[[VarValues, *Args], jax.Array] | None] = None
jac_custom_with_cache_fn: jdc.Static[Callable[[VarValues, Any, *Args], jax.Array] | None] = None
jac_mode: jdc.Static[Literal['auto', 'forward', 'reverse']] = 'auto'
kind: jdc.Static[CostKind] = 'l2_squared'
static make(compute_residual: jdc.Static[Callable[[VarValues, *Args_], jax.Array]], args: tuple[*Args_,], jac_mode: jdc.Static[Literal['auto', 'forward', 'reverse']] = 'auto', jac_custom_fn: jdc.Static[Callable[[VarValues, *Args_], jax.Array] | None] = None) -> Cost[*Args_,]

Parameters:
  • compute_residual (jdc.Static[Callable[[VarValues, *Args_], jax.Array]])

  • args (tuple[*Args_,])

  • jac_mode (jdc.Static[Literal['auto', 'forward', 'reverse']])

  • jac_custom_fn (jdc.Static[Callable[[VarValues, *Args_], jax.Array] | None])

Return type:

Cost[*Args_,]

name: jdc.Static[str | None] = None
compute_residual: jdc.Static[Callable[[VarValues, *Args], jax.Array] | Callable[[VarValues, *Args], tuple[jax.Array, Any]]]
args: tuple[*Args,]

Variables

class jaxls.Var

A symbolic representation of an optimization variable.

id: Array | int
tangent_dim: ClassVar[int]

Dimension of the tangent space.

retract_fn: ClassVar[Callable[[Any, Array], Any]]

Retraction function for the manifold. None for Euclidean space.
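A retraction maps a tangent-space step back onto the manifold. The NumPy sketch below contrasts the Euclidean case (ordinary addition) with 2D rotations stored as 2x2 matrices, where a one-dimensional tangent step δ is applied as R · Exp(δ). The function names are hypothetical and chosen only for this illustration.

```python
import numpy as np


def retract_euclidean(x: np.ndarray, delta: np.ndarray) -> np.ndarray:
    # Euclidean variables: the tangent space is the space itself.
    return x + delta


def so2_exp(theta: float) -> np.ndarray:
    # Exponential map from an angle to a 2x2 rotation matrix.
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s], [s, c]])


def retract_so2(R: np.ndarray, delta: np.ndarray) -> np.ndarray:
    # tangent_dim = 1: the step is a single angle, applied on the right.
    return R @ so2_exp(float(delta[0]))


R = so2_exp(0.3)
R_new = retract_so2(R, np.array([0.2]))
# The result is still a valid rotation (orthonormal, det = +1).
print(np.allclose(R_new.T @ R_new, np.eye(2)), np.isclose(np.linalg.det(R_new), 1.0))  # True True
# Angles compose: 0.3 + 0.2.
print(np.isclose(np.arctan2(R_new[1, 0], R_new[0, 0]), 0.5))  # True
```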

classmethod default_factory() -> T

Default value for this variable.

Return type:

T

with_value(value: T) -> VarWithValue[T]

Assign a value to this variable. The returned object can be passed to VarValues.make().

Parameters:

value (T)

Return type:

VarWithValue[T]

class jaxls.VarValues

A mapping from variables to variable values.

Given a variable object var and a values object vals, we can get the value by calling one of:

# Equivalent.
vals.get_value(var)
vals[var]

To get all values of a specific type var_type, use:

# Equivalent.
vals.get_stacked_value(var_type)
vals[var_type]

vals_from_type: dict[type[Var[Any]], Any]

Stacked values for each variable type, sorted by ID in ascending order.

ids_from_type: dict[type[Var[Any]], Array]

Variable ID for each value, sorted in ascending order.
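Because the stacked values and their IDs are both kept sorted by ID, looking up one or several variables can be done with a binary search. The sketch below shows the idea with NumPy's `searchsorted`; it illustrates why sorted storage is convenient and is not a description of jaxls's internal lookup. The arrays and the `lookup` helper are hypothetical.

```python
import numpy as np

# Hypothetical storage for one variable type: IDs sorted ascending,
# values stacked along axis 0 in the same order.
ids = np.array([2, 5, 9, 14])
stacked_values = np.array([[0.0, 0.2], [0.5, 0.5], [0.9, 0.1], [1.4, 0.0]])


def lookup(query_ids):
    """Return the stacked rows for a scalar ID or an array of IDs."""
    query_ids = np.asarray(query_ids)
    idx = np.searchsorted(ids, query_ids)
    # searchsorted returns an insertion point; clip it so IDs larger than all
    # stored IDs do not index out of range, then verify the ID is present.
    idx = np.minimum(idx, len(ids) - 1)
    if not np.all(ids[idx] == query_ids):
        raise KeyError(f"unknown id(s): {query_ids}")
    return stacked_values[idx]


print(lookup(9))                  # [0.9 0.1]
print(lookup(np.array([14, 2])))  # rows for IDs 14 and 2
```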

get_value(var: Var) -> T

Get the value of a specific variable or variables.

Parameters:

var (Var)

Return type:

T

get_stacked_value(var_type: type[Var]) -> T

Get the value of all variables of a specific type.

Parameters:

var_type (type[Var])

Return type:

T

static make(variables: Iterable[Var[Any] | VarWithValue[Any]]) -> VarValues

Create a VarValues object from an iterable of variables, with or without values assigned to them. Variables without an assigned value are set to the default value of their variable type.

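Example (SomeVar and AnotherVar stand for user-defined Var subclasses, and custom_value for a value of AnotherVar's type):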

>>> v1 = SomeVar(1)
>>> v2 = AnotherVar(2)
>>>
>>> # Set v1 to default, v2 to custom value:
>>> values = VarValues.make([v1, v2.with_value(custom_value)])
>>>
>>> # The previous example is equivalent to:
>>> values = VarValues.make([v1.with_value(v1.default_factory()), v2.with_value(custom_value)])

Parameters:

variables (Iterable[Var[Any] | VarWithValue[Any]])

Return type:

VarValues