Mean-variance allocation#
In this notebook, we solve a mean-variance portfolio optimization problem: finding the asset allocation that minimizes risk (portfolio variance) for a given target return, following the Markowitz framework. Sweeping the target return traces out the efficient frontier.
This example is adapted from the JuMP portfolio optimization tutorial.
Features used:
Var with vector-valued default
Equality constraints (constraint_eq_zero): budget constraint
Inequality constraints (constraint_geq_zero): minimum return, no short-selling
Augmented Lagrangian solver for constrained optimization
Efficient frontier via parametric sweeps
import jax
import jax.numpy as jnp
import jaxls
Historical stock data#
Monthly stock prices from November 2000 to November 2001 for three stocks: IBM, Walmart (WMT), and Southern Electric (SEHI).
stock_names = ["IBM", "WMT", "SEHI"]
# Monthly prices (13 months: Nov 2000 - Nov 2001)
prices = jnp.array(
[
[93.043, 51.826, 1.063],
[84.585, 52.823, 0.938],
[111.453, 56.477, 1.0],
[99.525, 49.805, 0.938],
[95.819, 50.287, 1.438],
[114.708, 51.521, 1.7],
[111.515, 51.531, 2.54],
[113.211, 48.664, 2.39],
[104.942, 55.744, 3.12],
[99.827, 47.916, 2.98],
[91.607, 49.438, 1.9],
[107.937, 51.336, 1.75],
[115.59, 55.081, 1.8],
]
)
print(f"Price data shape: {prices.shape} (months × stocks)")
Price data shape: (13, 3) (months × stocks)
Computing returns and covariance#
Monthly returns are computed as percentage changes, then we estimate expected returns and the covariance matrix.
# Monthly returns: (P[t+1] - P[t]) / P[t]
returns = jnp.diff(prices, axis=0) / prices[:-1]
# Expected return (mean of monthly returns)
expected_returns = jnp.mean(returns, axis=0)
# Covariance matrix (sample covariance)
returns_centered = returns - expected_returns
covariance = (returns_centered.T @ returns_centered) / (returns.shape[0] - 1)
print("Expected monthly returns:")
for name, r in zip(stock_names, expected_returns):
    print(f" {name}: {float(r) * 100:+.2f}%")
print(f"\nCovariance matrix:\n{covariance}")
Expected monthly returns:
IBM: +2.60%
WMT: +0.81%
SEHI: +7.37%
Covariance matrix:
[[0.01864104 0.00359853 0.00130976]
[0.00359853 0.00643694 0.00488726]
[0.00130976 0.00488726 0.06868275]]
Problem formulation#
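We want to invest $1000 across the three stocks so that portfolio variance is minimized while a target expected return is still achieved. The decision variable is the vector of portfolio weights, i.e. the fraction of the budget placed in each stock.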
Constraints:
Budget: total investment = $1000 (weights sum to 1)
Minimum return: expected return \(\geq\) target
No short-selling: all investments \(\geq\) 0
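Written as a quadratic program, the objective and the three constraints listed here can be stated as follows. The symbols are notation introduced for this sketch and do not appear in the code: \(w\) is the weight vector, \(\Sigma\) the covariance matrix, \(\mu\) the vector of expected monthly returns, and \(r_{\text{target}}\) the target return.

```latex
\[
\begin{aligned}
\min_{w \in \mathbb{R}^3} \quad & w^\top \Sigma\, w \\
\text{subject to} \quad & \mathbf{1}^\top w = 1, \\
& \mu^\top w \geq r_{\text{target}}, \\
& w \geq 0 .
\end{aligned}
\]
```

Under this notation, the dollar amount placed in stock \(i\) would be \(1000\, w_i\).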
class WeightsVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.ones(3) / 3):
    """Portfolio weights (3D vector)."""
weights_var = WeightsVar(id=0)
@jaxls.Cost.factory
def variance_cost(
    vals: jaxls.VarValues, var: WeightsVar, cov_chol: jax.Array
) -> jax.Array:
    """Minimize portfolio variance: ||L.T @ w||^2 = w.T @ cov @ w."""
    return cov_chol.T @ vals[var]
@jaxls.Cost.factory(kind="constraint_eq_zero")
def budget_constraint(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """Weights must sum to 1 (fully invested)."""
    return jnp.sum(vals[var]) - 1.0
@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
    vals: jaxls.VarValues, var: WeightsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
    """Expected return must meet target: E[r] >= target."""
    return jnp.array([jnp.dot(vals[var], exp_ret) - target])
@jaxls.Cost.factory(kind="constraint_geq_zero")
def no_short_constraint(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """No short-selling: weights >= 0."""
    return vals[var]
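The variance cost returns the residual `cov_chol.T @ w` rather than the variance itself, because a least-squares solver minimizes the squared norm of its residuals. The sketch below checks the identity from the docstring numerically, using the covariance matrix printed earlier and a weight vector `w` that is a hypothetical example, not a solver output.

```python
import jax.numpy as jnp

# Covariance matrix copied from the printed output earlier in the notebook.
covariance = jnp.array(
    [
        [0.01864104, 0.00359853, 0.00130976],
        [0.00359853, 0.00643694, 0.00488726],
        [0.00130976, 0.00488726, 0.06868275],
    ]
)

# jnp.linalg.cholesky returns the lower-triangular factor L with cov = L @ L.T.
cov_chol = jnp.linalg.cholesky(covariance)

# Hypothetical weights chosen only for this check; they sum to 1.
w = jnp.array([0.5, 0.3, 0.2])

# Residual as returned by the variance cost, and its squared norm.
residual = cov_chol.T @ w
sum_sq = float(jnp.sum(residual**2))

# Portfolio variance computed directly as a quadratic form.
quad = float(w @ covariance @ w)

print(f"||L.T @ w||^2 = {sum_sq:.8f}")
print(f"w.T @ cov @ w = {quad:.8f}")
```

The two printed values should agree up to floating-point error, since w.T @ (L @ L.T) @ w equals (L.T @ w).T @ (L.T @ w).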
Efficient frontier#
The efficient frontier shows the optimal trade-off between risk (variance) and return. We compute it by solving the optimization problem for different target returns.
Using jax.lax.scan, we solve sequentially while using each solution as the
initial guess for the next (warm-starting). This helps convergence since
adjacent target returns have similar optimal allocations.
# Cholesky decomposition for variance cost
cov_chol = jnp.linalg.cholesky(covariance)
# Range of target returns to explore
min_return = float(expected_returns.min())
max_return = float(expected_returns.max())
target_returns = jnp.linspace(min_return, max_return, 50)
def solve_for_target(
    current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    """Solve portfolio optimization for a given target return.

    Args:
        current_vals: Solution from previous target (used as initial guess).
        target: Target return for this solve.

    Returns:
        Tuple of (solution values, optimal weights).
    """
    costs = [
        variance_cost(weights_var, cov_chol),
        budget_constraint(weights_var),
        return_constraint(weights_var, expected_returns, target),
        no_short_constraint(weights_var),
    ]
    problem = jaxls.LeastSquaresProblem(costs, [weights_var]).analyze()
    # Use dense Cholesky solver for this small problem
    solution = problem.solve(
        current_vals,
        verbose=False,
        linear_solver="dense_cholesky",
        termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    )
    return solution, solution[weights_var]
# Solve sequentially with warm-starting.
initial_vals = jaxls.VarValues.make([weights_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, target_returns)
variances = jax.vmap(lambda w: w @ covariance @ w)(all_weights)
returns_achieved = jax.vmap(lambda w: jnp.dot(w, expected_returns))(all_weights)
print(f"Computed {len(target_returns)} points on the efficient frontier")
INFO | Building optimization problem with 4 terms and 1 variables: 1 costs, 1 eq_zero, 0 leq_zero, 2 geq_zero
INFO | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
INFO | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_budget_constraint
INFO | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_no_short_constraint
Computed 50 points on the efficient frontier
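The warm-starting pattern can also be seen in isolation, without jaxls. The sketch below is a toy stand-in, not the portfolio solver: a hypothetical `newton_solve` finds the root of x**3 + x = t for a sequence of targets, and `jax.lax.scan` threads each solution through as the initial guess for the next target, mirroring how `solve_for_target` receives `current_vals`.

```python
import jax
import jax.numpy as jnp


def newton_solve(x0, t, num_steps=20):
    """Toy solver (hypothetical): find x with x**3 + x == t via Newton steps."""

    def step(x, _):
        f = x**3 + x - t
        df = 3.0 * x**2 + 1.0  # derivative is always >= 1, so no division by zero
        return x - f / df, None

    # Run a fixed number of Newton iterations starting from x0.
    x, _ = jax.lax.scan(step, x0, None, length=num_steps)
    return x


def solve_for_target(carry, t):
    """Scan body: the carry is the previous solution, used as the initial guess."""
    x = newton_solve(carry, t)
    # First element becomes the next carry (warm start); second is stacked as output.
    return x, x


targets = jnp.linspace(0.0, 10.0, 5)
_, xs = jax.lax.scan(solve_for_target, jnp.array(0.0), targets)

# Each solution should satisfy x**3 + x == target up to floating-point error.
residuals = xs**3 + xs - targets
max_residual = float(jnp.max(jnp.abs(residuals)))
print(f"max |residual| = {max_residual:.2e}")
```

The scan body has the same shape as `solve_for_target` in the notebook: it takes the carried state and one element of the swept parameter, and returns the new state together with the value to collect.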
Results#
Three views of the efficient frontier:
Objective space: Standard deviation vs. expected return
Risk-adjusted return: Sharpe ratio (return/risk) along the frontier
Decision space: Asset allocation across the frontier
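The quantities behind the first two views can be computed from a batch of weight vectors as sketched below. Assumptions: the expected returns are rounded from the printed output earlier, the covariance is copied from the printed matrix, the three rows of `weights` are hypothetical portfolios rather than solver results, and the Sharpe ratio is taken as return divided by standard deviation with no risk-free rate, matching the definition used in this list.

```python
import jax.numpy as jnp

# Rounded from the printed expected monthly returns (IBM, WMT, SEHI).
expected_returns = jnp.array([0.0260, 0.0081, 0.0737])

# Copied from the printed covariance matrix.
covariance = jnp.array(
    [
        [0.01864104, 0.00359853, 0.00130976],
        [0.00359853, 0.00643694, 0.00488726],
        [0.00130976, 0.00488726, 0.06868275],
    ]
)

# Hypothetical portfolios: all WMT, equal weights, all SEHI.
weights = jnp.array(
    [
        [0.0, 1.0, 0.0],
        [1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0],
        [0.0, 0.0, 1.0],
    ]
)

# Per-portfolio variance w.T @ cov @ w, then standard deviation.
variances = jnp.einsum("ni,ij,nj->n", weights, covariance, weights)
stds = jnp.sqrt(variances)

# Per-portfolio expected return and return/risk ratio.
rets = weights @ expected_returns
sharpe = rets / stds

for w, s, r, sr in zip(weights, stds, rets, sharpe):
    print(f"weights={w}, std={float(s):.4f}, return={float(r):.4f}, ratio={float(sr):.3f}")
```

Applying the same three lines to `all_weights` from the sweep would give the curves for the objective-space and risk-adjusted views.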
The efficient frontier shows the optimal trade-off between risk (standard deviation) and return. SEHI has the highest expected return (+7.37% per month) but also the highest variance, while WMT has the lowest variance and the lowest expected return. The optimal allocation shifts from WMT-heavy (low risk) to SEHI-heavy (high return) as we move along the frontier.
For more details, see jaxls.Cost and jaxls.LeastSquaresProblem.