Manifold allocation#
The mean-variance notebook treats allocation as a classical constrained optimization problem: we minimize variance subject to budget constraints (weights sum to 1) and no short-selling (weights non-negative).
An alternative perspective: the set of valid allocations forms a manifold called the
probability simplex. By defining a custom retract_fn, we can optimize directly on this manifold
without explicit constraints.
Features used:
- `Var` with custom `retract_fn` and `tangent_dim` for manifold variables
- Multiplicative retraction for smooth simplex optimization
import jax
import jax.numpy as jnp
import jaxls
The probability simplex#
The probability simplex \(\Delta^{n-1}\) is the set of valid portfolio allocations:
\[
\Delta^{n-1} = \left\{ w \in \mathbb{R}^n \;\middle|\; \sum_{i=1}^{n} w_i = 1,\ w_i \geq 0 \right\}
\]
This is an \((n-1)\)-dimensional manifold embedded in \(\mathbb{R}^n\). The sum-to-one constraint removes one degree of freedom.
To optimize on this manifold, we need a retraction: a function retract(x, delta) that maps
a point on the manifold plus a tangent update back to the manifold.
See Non-Euclidean variables for background on manifold optimization.
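For contrast, an unconstrained array variable would update by plain vector addition, which can push weights negative or off the sum-to-one constraint. A minimal sketch of that baseline (added here for illustration, not code from the original notebook):
def additive_retract(x: jax.Array, delta: jax.Array) -> jax.Array:
    """Plain Euclidean update: nothing keeps the result on the simplex."""
    return x + delta
The simplex retraction below replaces this additive step with a multiplicative one.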
Simplex retraction#
Our retraction uses multiplicative updates, which are additive in log-space:
1. Scale each weight: `weights * exp(delta)`
2. Normalize: divide by the sum to ensure the weights sum to one
This satisfies the retraction requirements:
- Identity at zero: `exp(0) = 1`, so `retract(w, 0) = w / sum(w) = w`
- Stays on manifold: normalization ensures sum-to-one; `exp()` ensures positivity
- Smoothness: both `exp()` and division are smooth operations
def simplex_retract(weights: jax.Array, delta: jax.Array) -> jax.Array:
"""Multiplicative update that keeps weights positive and summing to 1."""
v = weights * jnp.exp(delta)
return v / jnp.sum(v)
# Verify the retraction works.
test_weights = jnp.array([0.5, 0.3, 0.2])
test_delta = jnp.array([0.1, -0.4, 0.1]) # Would make middle weight negative.
result = simplex_retract(test_weights, test_delta)
print(f"Initial weights: {test_weights} (sum={float(jnp.sum(test_weights)):.4f})")
print(f"After adding delta: {test_weights + test_delta}")
print(f"After retraction: {result} (sum={float(jnp.sum(result)):.4f})")
print(f"All non-negative: {bool(jnp.all(result >= 0))}")
Initial weights: [0.5 0.3 0.2] (sum=1.0000)
After adding delta: [ 0.6 -0.09999999 0.3 ]
After retraction: [0.5669196 0.20631248 0.22676787] (sum=1.0000)
All non-negative: True
Simplex variable type#
Now we define a variable type that lives on the simplex. The retract_fn ensures all updates
produce valid allocations:
n_assets = 3
class SimplexWeightsVar(
jaxls.Var[jax.Array],
default_factory=lambda: jnp.ones(n_assets) / n_assets,
retract_fn=simplex_retract,
tangent_dim=n_assets, # One delta per weight; retraction enforces constraints.
):
"""Portfolio weights on the probability simplex."""
print(f"SimplexWeightsVar tangent_dim: {SimplexWeightsVar.tangent_dim}")
SimplexWeightsVar tangent_dim: 3
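A side note, verified with a small check added here (not from the original notebook): the simplex is only \((n-1)\)-dimensional, but `tangent_dim` is \(n\). The tangent parameterization is overcomplete, and harmlessly so, because a uniform shift of `delta` cancels in the normalization:
# exp(delta + c) = exp(c) * exp(delta), and the scalar exp(c) cancels when
# dividing by the sum, so a uniform shift leaves the retracted point unchanged.
delta = jnp.array([0.2, -0.1, 0.05])
print(simplex_retract(test_weights, delta))
print(simplex_retract(test_weights, delta + 1.7))  # same point, up to rounding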
Asset data#
We use the same historical stock data as the mean-variance notebook.
stock_names = ["IBM", "WMT", "SEHI"]
# Monthly prices (13 months: Nov 2000 - Nov 2001).
prices = jnp.array(
[
[93.043, 51.826, 1.063],
[84.585, 52.823, 0.938],
[111.453, 56.477, 1.0],
[99.525, 49.805, 0.938],
[95.819, 50.287, 1.438],
[114.708, 51.521, 1.7],
[111.515, 51.531, 2.54],
[113.211, 48.664, 2.39],
[104.942, 55.744, 3.12],
[99.827, 47.916, 2.98],
[91.607, 49.438, 1.9],
[107.937, 51.336, 1.75],
[115.59, 55.081, 1.8],
]
)
# Compute returns and covariance.
returns = jnp.diff(prices, axis=0) / prices[:-1]
expected_returns = jnp.mean(returns, axis=0)
returns_centered = returns - expected_returns
covariance = (returns_centered.T @ returns_centered) / (returns.shape[0] - 1)
cov_chol = jnp.linalg.cholesky(covariance)
print("Expected monthly returns:")
for name, r in zip(stock_names, expected_returns):
print(f" {name}: {float(r) * 100:+.2f}%")
Expected monthly returns:
IBM: +2.60%
WMT: +0.81%
SEHI: +7.37%
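As a quick sanity check (an extra cell, not in the original notebook), the hand-computed sample covariance should match `jnp.cov`, which also normalizes by \(N-1\) by default:
# jnp.cov expects variables along rows, so pass the transposed returns matrix.
print(bool(jnp.allclose(covariance, jnp.cov(returns.T), atol=1e-6)))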
Problem formulation#
With the simplex variable, we only need a variance cost and a return constraint. The budget and no short-selling constraints are handled implicitly by the manifold geometry.
weights_var = SimplexWeightsVar(id=0)
@jaxls.Cost.factory
def variance_cost(
vals: jaxls.VarValues, var: SimplexWeightsVar, cov_chol: jax.Array
) -> jax.Array:
"""Minimize portfolio variance: ||L.T @ w||^2 = w.T @ cov @ w."""
return cov_chol.T @ vals[var]
@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
vals: jaxls.VarValues, var: SimplexWeightsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
"""Expected return must meet target: E[r] >= target."""
return jnp.array([jnp.dot(vals[var], exp_ret) - target])
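To see why the Cholesky factor appears in the residual, here is a quick numerical check (added for illustration) that the squared norm of the whitened residual equals the quadratic form \(w^\top \Sigma w\):
# covariance = L @ L.T, so ||L.T @ w||^2 = w.T @ covariance @ w.
w0 = jnp.ones(n_assets) / n_assets  # uniform starting allocation
residual = cov_chol.T @ w0
print(float(residual @ residual), float(w0 @ covariance @ w0))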
Efficient frontier#
We compute the efficient frontier by solving for different target returns. Note that we only have one constraint (return target) instead of three.
Using jax.lax.scan, we solve sequentially while using each solution as the
initial guess for the next (warm-starting).
min_return = float(expected_returns.min())
max_return = float(expected_returns.max())
target_returns = jnp.linspace(min_return, max_return, 50)
def solve_for_target(
current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
"""Solve portfolio optimization for a given target return.
Args:
current_vals: Solution from previous target (used as initial guess).
target: Target return for this solve.
Returns:
Tuple of (solution values, optimal weights).
"""
costs = [
variance_cost(weights_var, cov_chol),
return_constraint(weights_var, expected_returns, target),
]
problem = jaxls.LeastSquaresProblem(costs, [weights_var]).analyze()
solution = problem.solve(
current_vals,
verbose=False,
linear_solver="dense_cholesky",
termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
)
return solution, solution[weights_var]
# Solve sequentially with warm-starting.
initial_vals = jaxls.VarValues.make([weights_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, target_returns)
variances = jax.vmap(lambda w: w @ covariance @ w)(all_weights)
returns_achieved = jax.vmap(lambda w: jnp.dot(w, expected_returns))(all_weights)
print(f"Computed {len(target_returns)} points on the efficient frontier")
INFO | Building optimization problem with 2 terms and 1 variables: 1 costs, 0 eq_zero, 0 leq_zero, 1 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
Computed 50 points on the efficient frontier
# Verify all solutions satisfy simplex constraints.
weight_sums = jnp.sum(all_weights, axis=1)
min_weights = jnp.min(all_weights, axis=1)
print(
f"Weight sums: min={float(weight_sums.min()):.6f}, max={float(weight_sums.max()):.6f}"
)
print(f"Min weight across all solutions: {float(min_weights.min()):.6f}")
print("All constraints satisfied by construction!")
Weight sums: min=1.000000, max=1.000000
Min weight across all solutions: 0.000000
All constraints satisfied by construction!
Results#
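A minimal plotting sketch (assuming `matplotlib` is available in the environment; the figure in the original notebook may differ) for visualizing the frontier computed above:
import matplotlib.pyplot as plt

# Risk (monthly standard deviation) on the x-axis, expected return on the y-axis.
plt.plot(jnp.sqrt(variances) * 100, returns_achieved * 100, marker=".")
plt.xlabel("Monthly standard deviation (%)")
plt.ylabel("Expected monthly return (%)")
plt.title("Efficient frontier on the probability simplex")
plt.show()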
Comparison with constrained approach#
The manifold and constrained approaches produce the same efficient frontier. The key difference is in how constraints are enforced:
| Approach | Budget constraint | No short-selling | Implementation |
|---|---|---|---|
| Constrained | Explicit equality constraint | Explicit inequality constraint | Augmented Lagrangian |
| Manifold | Built into `retract_fn` | Built into `retract_fn` | Multiplicative retraction |
The manifold approach can be advantageous when:
- Constraints define a smooth manifold with a simple retraction (the sphere sketch below is another example)
- You want to avoid the overhead of the Augmented Lagrangian solver
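As another instance of the first point, unit-norm vectors (the sphere) also admit a simple retraction. A minimal sketch in the same spirit as `simplex_retract` (an illustration added here, not part of the jaxls API):
def sphere_retract(x: jax.Array, delta: jax.Array) -> jax.Array:
    """Additive step followed by projection back onto the unit sphere."""
    v = x + delta
    return v / jnp.linalg.norm(v)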