Mean-variance allocation#

In this notebook, we solve a mean-variance portfolio optimization problem: finding the asset allocation that minimizes risk (portfolio variance) for a given target return, following the Markowitz framework.

This example is adapted from the JuMP portfolio optimization tutorial.

Features used:

  • Var with vector-valued default

  • Equality constraints (constraint_eq_zero): budget constraint

  • Inequality constraints (constraint_geq_zero): minimum return, no short-selling

  • Augmented Lagrangian solver for constrained optimization

  • Efficient frontier via parametric sweeps

Hide code cell source

import sys
from loguru import logger

# Clear the handlers currently registered on the loguru logger so that only
# the compact format configured on the next line is used for log output.
logger.remove()
# Log to stdout as "LEVEL    | message". The trailing semicolon keeps the
# notebook from echoing the value returned by logger.add().
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls

Historical stock data#

Monthly stock prices from November 2000 to November 2001 for three stocks: IBM, Walmart (WMT), and Southern Electric (SEHI).

# Ticker labels, in the same order as the columns of `prices`.
stock_names = ["IBM", "WMT", "SEHI"]

# Monthly prices, one row per month (13 rows) and one column per ticker.
prices = jnp.array(
    [
        # IBM      WMT     SEHI
        [93.043, 51.826, 1.063],  # Nov 2000
        [84.585, 52.823, 0.938],  # Dec 2000
        [111.453, 56.477, 1.000],  # Jan 2001
        [99.525, 49.805, 0.938],  # Feb 2001
        [95.819, 50.287, 1.438],  # Mar 2001
        [114.708, 51.521, 1.700],  # Apr 2001
        [111.515, 51.531, 2.540],  # May 2001
        [113.211, 48.664, 2.390],  # Jun 2001
        [104.942, 55.744, 3.120],  # Jul 2001
        [99.827, 47.916, 2.980],  # Aug 2001
        [91.607, 49.438, 1.900],  # Sep 2001
        [107.937, 51.336, 1.750],  # Oct 2001
        [115.590, 55.081, 1.800],  # Nov 2001
    ]
)

print(f"Price data shape: {prices.shape} (months × stocks)")
Price data shape: (13, 3) (months × stocks)

Hide code cell source

import plotly.graph_objects as go
from IPython.display import HTML

# X-axis labels, one per row of `prices`: Nov '00, Dec '00, then Jan-Nov '01.
months = ["Nov '00", "Dec '00"] + [
    f"{abbr} '01"
    for abbr in (
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
    )
]

# One line color per stock, in `stock_names` order.
colors = ["#2196F3", "#4CAF50", "#FF9800"]

fig = go.Figure()
# Transposing `prices` yields one price series per stock.
for name, color, series in zip(stock_names, colors, prices.T):
    trace = go.Scatter(
        x=months,
        y=series,
        mode="lines+markers",
        name=name,
        line={"color": color, "width": 2},
        marker={"size": 6},
    )
    fig.add_trace(trace)

fig.update_layout(
    title="Historical Stock Prices",
    xaxis_title="Month",
    yaxis_title="Price ($)",
    height=350,
    margin={"t": 40, "b": 40, "l": 60, "r": 40},
    legend={
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.02,
        "xanchor": "center",
        "x": 0.5,
    },
    hovermode="x unified",
)
# Last expression of the cell: render the figure as embeddable HTML.
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

Computing returns and covariance#

Monthly returns are computed as percentage changes, then we estimate expected returns and the covariance matrix.

# Simple month-over-month returns: (P[t+1] - P[t]) / P[t], one row per period.
returns = (prices[1:] - prices[:-1]) / prices[:-1]

# Expected return of each stock: the mean of its monthly returns.
expected_returns = returns.mean(axis=0)

# Sample covariance of the monthly returns (normalized by N - 1).
num_periods = returns.shape[0]
returns_centered = returns - expected_returns
covariance = (returns_centered.T @ returns_centered) / (num_periods - 1)

print("Expected monthly returns:")
for name, mean_return in zip(stock_names, expected_returns.tolist()):
    print(f"  {name}: {mean_return * 100:+.2f}%")
print(f"\nCovariance matrix:\n{covariance}")
Expected monthly returns:
  IBM: +2.60%
  WMT: +0.81%
  SEHI: +7.37%

Covariance matrix:
[[0.01864104 0.00359853 0.00130976]
 [0.00359853 0.00643694 0.00488726]
 [0.00130976 0.00488726 0.06868275]]

Problem formulation#

We want to invest $1000 to minimize portfolio variance while achieving a target return.

Constraints:

  • Budget: total investment = $1000 (weights sum to 1)

  • Minimum return: expected return \(\geq\) target

  • No short-selling: all investments \(\geq\) 0

class WeightsVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.ones(3) / 3):
    """Portfolio weights: a length-3 vector with one entry per stock.

    The default value is the equal-weight portfolio (1/3 in each asset).
    """


# The single decision variable; every cost and constraint in this notebook
# is built on it.
weights_var = WeightsVar(id=0)
@jaxls.Cost.factory
def variance_cost(
    vals: jaxls.VarValues, var: WeightsVar, cov_chol: jax.Array
) -> jax.Array:
    """Residual whose squared norm is the portfolio variance.

    With ``cov = L @ L.T`` and ``cov_chol = L``, the residual ``L.T @ w``
    satisfies ``||L.T @ w||^2 = w.T @ cov @ w``.
    """
    weights = vals[var]
    return cov_chol.T @ weights


@jaxls.Cost.factory(kind="constraint_eq_zero")
def budget_constraint(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """Equality residual ``sum(w) - 1``: zero when fully invested."""
    total_weight = vals[var].sum()
    return total_weight - 1.0


@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
    vals: jaxls.VarValues, var: WeightsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
    """Inequality residual ``w . exp_ret - target`` as a length-1 array.

    Non-negative exactly when the portfolio's expected return reaches
    ``target``.
    """
    portfolio_return = jnp.dot(vals[var], exp_ret)
    slack = portfolio_return - target
    return jnp.array([slack])


@jaxls.Cost.factory(kind="constraint_geq_zero")
def no_short_constraint(vals: jaxls.VarValues, var: WeightsVar) -> jax.Array:
    """Inequality residual equal to the weights: each must be >= 0."""
    weights = vals[var]
    return weights

Efficient frontier#

The efficient frontier shows the optimal trade-off between risk (variance) and return. We compute it by solving the optimization problem for different target returns.

Using jax.lax.scan, we solve sequentially while using each solution as the initial guess for the next (warm-starting). This helps convergence since adjacent target returns have similar optimal allocations.

# Cholesky factor of the covariance, consumed by variance_cost.
cov_chol = jnp.linalg.cholesky(covariance)

# Sweep 50 evenly spaced targets between the smallest and largest
# single-stock expected return.
min_return = float(jnp.min(expected_returns))
max_return = float(jnp.max(expected_returns))
target_returns = jnp.linspace(min_return, max_return, num=50)


def solve_for_target(
    current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    """Run one step of the target-return sweep.

    The ``(carry, x) -> (carry, y)`` signature lets this function be used
    directly as the body of ``jax.lax.scan``.

    Args:
        current_vals: Values to start the solve from (the previous target's
            solution when used in the sweep).
        target: Minimum expected return required for this solve.

    Returns:
        Tuple of (solution values, optimal weights).
    """
    # Objective first, then the three constraints.
    objective = variance_cost(weights_var, cov_chol)
    constraints = [
        budget_constraint(weights_var),
        return_constraint(weights_var, expected_returns, target),
        no_short_constraint(weights_var),
    ]
    problem = jaxls.LeastSquaresProblem(
        [objective, *constraints], [weights_var]
    ).analyze()

    # The problem is tiny, so the dense Cholesky linear solver is used.
    termination = jaxls.TerminationConfig(cost_tolerance=1e-8)
    solution = problem.solve(
        current_vals,
        verbose=False,
        linear_solver="dense_cholesky",
        termination=termination,
    )
    optimal_weights = solution[weights_var]
    return solution, optimal_weights


# Sweep the targets in order. scan passes each solution on as the initial
# guess for the next solve and stacks the per-target weights.
initial_vals = jaxls.VarValues.make([weights_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, xs=target_returns)

# Variance and expected return of every portfolio found by the sweep.
variances = jax.vmap(lambda w: jnp.dot(w @ covariance, w))(all_weights)
returns_achieved = jax.vmap(lambda w: w @ expected_returns)(all_weights)

print(f"Computed {target_returns.shape[0]} points on the efficient frontier")
INFO     | Building optimization problem with 4 terms and 1 variables: 1 costs, 1 eq_zero, 0 leq_zero, 2 geq_zero
INFO     | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
INFO     | Vectorizing constraint group with 1 constraints (constraint_eq_zero), 1 variables each: augmented_budget_constraint
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_no_short_constraint
Computed 50 points on the efficient frontier

Results#

Three views of the efficient frontier:

  1. Objective space: Standard deviation vs. expected return

  2. Risk-adjusted return: Sharpe ratio (excess return over the risk-free rate, divided by risk) along the frontier

  3. Decision space: Asset allocation across the frontier

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML
import math

# Monthly -> yearly scaling: returns are multiplied by the number of periods,
# standard deviations by its square root.
annual_factor = 12
risk_free_rate = 0.05  # 5% annual risk-free rate.

# Pull the JAX results out as plain Python floats for Plotly.
sqrt_annual_factor = math.sqrt(annual_factor)
std_devs_annual = [s * sqrt_annual_factor for s in jnp.sqrt(variances).tolist()]
returns_annual = [r * annual_factor for r in returns_achieved.tolist()]


def _sharpe(annual_return, annual_std):
    """Excess return over the risk-free rate per unit of risk (0.0 if risk <= 0)."""
    if annual_std > 0:
        return (annual_return - risk_free_rate) / annual_std
    return 0.0


sharpe_ratios = [_sharpe(r, s) for r, s in zip(returns_annual, std_devs_annual)]

colors = ["#2196F3", "#4CAF50", "#FF9800"]

fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Efficient Frontier", "Sharpe Ratio", "Asset Allocation"),
    column_widths=[0.33, 0.33, 0.34],
)

# Express everything in dollars of a $1000 investment.
investment = 1000
std_dev_dollars = [s * investment for s in std_devs_annual]
return_dollars = [r * investment for r in returns_annual]
# The bar chart is indexed by *target* return: achieved returns can coincide
# for low targets that fall below the min-variance portfolio's return, which
# would make the bars overlap.
target_dollars = [t * annual_factor * investment for t in target_returns.tolist()]


def _frontier_trace(x, y, hovertemplate):
    """Build a line+marker trace whose markers are colored by Sharpe ratio."""
    return go.Scatter(
        x=x,
        y=y,
        mode="lines+markers",
        marker={
            "size": 6,
            "color": sharpe_ratios,
            "colorscale": "Viridis",
            "showscale": False,
        },
        line={"color": "steelblue", "width": 2},
        hovertemplate=hovertemplate,
        showlegend=False,
    )


# Left plot: efficient frontier (std dev vs return).
fig.add_trace(
    _frontier_trace(
        std_dev_dollars,
        return_dollars,
        "Std Dev: $%{x:.0f}<br>Return: $%{y:.0f}<extra></extra>",
    ),
    row=1,
    col=1,
)

# Middle plot: Sharpe ratio vs return.
fig.add_trace(
    _frontier_trace(
        return_dollars,
        sharpe_ratios,
        "Return: $%{x:.0f}<br>Sharpe: %{y:.2f}<extra></extra>",
    ),
    row=1,
    col=2,
)

# Right plot: stacked dollar allocation per stock vs target return.
for name, color, allocation in zip(stock_names, colors, all_weights.T.tolist()):
    bar = go.Bar(
        x=target_dollars,
        y=[w * investment for w in allocation],
        name=name,
        marker_color=color,
        hovertemplate=f"{name}: $%{{y:.0f}}<extra></extra>",
    )
    fig.add_trace(bar, row=1, col=3)

# (x title, y title) for subplot columns 1..3.
axis_titles = [
    ("Std Dev ($)", "Return ($)"),
    ("Return ($)", "Sharpe Ratio"),
    ("Target Return ($)", "Investment ($)"),
]
for col, (x_title, y_title) in enumerate(axis_titles, start=1):
    fig.update_xaxes(title_text=x_title, row=1, col=col)
    fig.update_yaxes(title_text=y_title, row=1, col=col)
# Leave a little headroom above the $1000 stacked total.
fig.update_yaxes(range=[0, 1050], row=1, col=3)

fig.update_layout(
    barmode="stack",
    height=400,
    margin={"t": 40, "b": 40, "l": 60, "r": 40},
    legend={
        "orientation": "h",
        "yanchor": "bottom",
        "y": -0.25,
        "xanchor": "center",
        "x": 0.83,
    },
)
# Last expression of the cell: render the figure as embeddable HTML.
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))

The efficient frontier shows the optimal trade-off between risk (standard deviation) and return. SEHI has the highest expected return but also the highest risk, while WMT provides stability. The optimal allocation shifts from WMT-heavy (low risk) to SEHI-heavy (high return) as we move along the frontier.

For more details, see jaxls.Cost and jaxls.LeastSquaresProblem.