Manifold allocation

The mean-variance notebook treats allocation as a classical constrained optimization problem. It minimizes variance subject to three constraints: a budget constraint (weights sum to 1), a no-short-selling constraint (weights are non-negative), and a minimum target return.

An alternative perspective: the set of valid allocations is the probability simplex, and its interior is a smooth manifold. By defining a custom retract_fn, we can optimize directly on this manifold. The budget and non-negativity constraints then no longer need to be stated explicitly.

Features used:

  • Var with custom retract_fn and tangent_dim for manifold variables

  • Multiplicative retraction for smooth simplex optimization


import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import jax
import jax.numpy as jnp
import jaxls

The probability simplex

The probability simplex \(\Delta^{n-1}\) is the set of valid portfolio allocations:

\[\Delta^{n-1} = \{w \in \mathbb{R}^n : w_i \geq 0, \sum_i w_i = 1\}\]

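Its interior, where every weight is strictly positive, is an \((n-1)\)-dimensional smooth manifold embedded in \(\mathbb{R}^n\). The sum-to-one constraint removes one degree of freedom. Once the boundary faces where some weights are zero are included, the simplex is a manifold with corners rather than a smooth manifold.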
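To make the dimension count concrete, consider an interior point \(w\). The directions that stay on the simplex to first order are exactly those that leave the total weight unchanged:

\[T_w\Delta^{n-1} = \{v \in \mathbb{R}^n : \textstyle\sum_i v_i = 0\}, \qquad \dim T_w\Delta^{n-1} = n - 1.\]

For example, with \(n = 3\), the vector \(v = (0.1, -0.05, -0.05)\) is a tangent direction. The vector \((0.1, 0, 0)\) is not, because it would push the total weight to 1.1.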

To optimize on this manifold, we need a retraction: a function retract(x, delta) that maps a point on the manifold plus a tangent update back to the manifold. See Non-Euclidean variables for background on manifold optimization.

Simplex retraction

Our retraction uses multiplicative updates, which are additive in log-space:

  1. Scale each weight: weights * exp(delta)

  2. Normalize: divide by sum to ensure sum-to-one

This satisfies the retraction requirements:

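  • Identity at zero: exp(0) = 1, so retract(w, 0) = w / sum(w) = w, because w already sums to one

  • Stays on manifold: normalization ensures sum-to-one, and exp() keeps strictly positive weights strictly positive (a weight that is exactly zero stays zero, so the starting point must lie in the interior)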

  • Smoothness: both exp() and division are smooth operations
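Retractions in the Riemannian sense are usually also required to agree with the identity to first order. The multiplicative update does not satisfy that literally, because its tangent coordinates live in \(\mathbb{R}^n\). What it does satisfy is this: its differential at zero, delta ↦ w ⊙ delta − w (w · delta), maps onto the tangent space {v : Σ v_i = 0} whenever all weights are positive.

The sketch below checks this numerically with jax.jacfwd. The helper retract_jacobian_at_zero is hypothetical, and the weights are the same test values used in the next code cell.

```python
import jax
import jax.numpy as jnp


def simplex_retract(weights: jax.Array, delta: jax.Array) -> jax.Array:
    """Multiplicative update that keeps weights positive and summing to 1."""
    v = weights * jnp.exp(delta)
    return v / jnp.sum(v)


def retract_jacobian_at_zero(weights: jax.Array) -> jax.Array:
    """Hypothetical helper: Jacobian of delta -> simplex_retract(weights, delta) at delta = 0."""
    return jax.jacfwd(lambda d: simplex_retract(weights, d))(jnp.zeros_like(weights))


w = jnp.array([0.5, 0.3, 0.2])
J = retract_jacobian_at_zero(w)

# Closed form of the same Jacobian: diag(w) - w w^T.
J_expected = jnp.diag(w) - jnp.outer(w, w)

# Each column sums to zero, so every image J @ delta is a tangent vector.
column_sums = jnp.sum(J, axis=0)

# Two singular values are clearly nonzero; the third is ~0 (the all-ones direction).
singular_values = jnp.linalg.svd(J, compute_uv=False)
rank = int(jnp.sum(singular_values > 1e-4))

print(jnp.allclose(J, J_expected, atol=1e-6), jnp.allclose(column_sums, 0.0, atol=1e-6), rank)
# True True 2
```

Rank \(n - 1\) means every tangent direction is reachable from some delta. This is the property the optimizer needs, even though the parameterization is not a retraction in the strict first-order sense.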

def simplex_retract(weights: jax.Array, delta: jax.Array) -> jax.Array:
    """Multiplicative update that keeps weights positive and summing to 1."""
    v = weights * jnp.exp(delta)
    return v / jnp.sum(v)
# Verify the retraction works.
test_weights = jnp.array([0.5, 0.3, 0.2])
test_delta = jnp.array([0.1, -0.4, 0.1])  # Would make middle weight negative.

result = simplex_retract(test_weights, test_delta)
print(f"Initial weights: {test_weights} (sum={float(jnp.sum(test_weights)):.4f})")
print(f"After adding delta: {test_weights + test_delta}")
print(f"After retraction: {result} (sum={float(jnp.sum(result)):.4f})")
print(f"All non-negative: {bool(jnp.all(result >= 0))}")
Initial weights: [0.5 0.3 0.2] (sum=1.0000)
After adding delta: [ 0.6        -0.09999999  0.3       ]
After retraction: [0.5669196  0.20631248 0.22676787] (sum=1.0000)
All non-negative: True
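The phrase "additive in log-space" can be made concrete. Scaling by exp(delta) and then renormalizing is the same as adding delta to the log-weights and applying a softmax.

The sketch below checks that equivalence on the test values above, using a hypothetical simplex_retract_log_space helper. It also shows one design consideration. jax.nn.softmax subtracts the maximum before exponentiating, so it stays finite for large deltas where the direct form overflows in JAX's default float32. The log-space form does need strictly positive weights, since log(0) is -inf.

```python
import jax
import jax.numpy as jnp


def simplex_retract(weights: jax.Array, delta: jax.Array) -> jax.Array:
    v = weights * jnp.exp(delta)
    return v / jnp.sum(v)


def simplex_retract_log_space(weights: jax.Array, delta: jax.Array) -> jax.Array:
    """Hypothetical equivalent form: softmax(log(w) + delta)."""
    return jax.nn.softmax(jnp.log(weights) + delta)


w = jnp.array([0.5, 0.3, 0.2])
delta = jnp.array([0.1, -0.4, 0.1])
same = jnp.allclose(simplex_retract(w, delta), simplex_retract_log_space(w, delta), atol=1e-6)
print(same)  # True

# A large step: in float32, exp(100) overflows to inf, so the direct form yields nan.
big_delta = jnp.array([100.0, 0.0, 0.0])
print(simplex_retract(w, big_delta))
print(simplex_retract_log_space(w, big_delta))  # finite, sums to 1
```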

Simplex variable type

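Now we define a variable type that lives on the simplex. Its default value is the uniform allocation \(1/n\), which lies in the interior as the multiplicative update requires. The retract_fn ensures that every update produces a valid allocation: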

n_assets = 3


class SimplexWeightsVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.ones(n_assets) / n_assets,
    retract_fn=simplex_retract,
    tangent_dim=n_assets,  # One delta per weight (one redundant direction); retraction enforces constraints.
):
    """Portfolio weights on the probability simplex."""


print(f"SimplexWeightsVar tangent_dim: {SimplexWeightsVar.tangent_dim}")
SimplexWeightsVar tangent_dim: 3
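With tangent_dim=n_assets, the solver works with three tangent coordinates for a two-dimensional simplex. The extra direction is the all-ones vector. Adding the same constant to every entry of delta multiplies every weight by the same factor, and the normalization then cancels it.

The sketch below checks this. It re-defines simplex_retract so it runs on its own; the particular delta and shift are arbitrary. The redundancy means a step is determined only up to a uniform shift. It does not restrict which allocations are reachable.

```python
import jax.numpy as jnp


def simplex_retract(weights, delta):
    v = weights * jnp.exp(delta)
    return v / jnp.sum(v)


w = jnp.ones(3) / 3  # Same uniform default as SimplexWeightsVar.
delta = jnp.array([0.2, -0.1, 0.05])
shift = 0.7  # Arbitrary constant added to every tangent coordinate.

a = simplex_retract(w, delta)
b = simplex_retract(w, delta + shift)
print(jnp.allclose(a, b, atol=1e-6))  # True: the shift cancels in the normalization
```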

Asset data

We use the same historical stock data as the mean-variance notebook.

stock_names = ["IBM", "WMT", "SEHI"]

# Monthly prices (13 months: Nov 2000 - Nov 2001).
prices = jnp.array(
    [
        [93.043, 51.826, 1.063],
        [84.585, 52.823, 0.938],
        [111.453, 56.477, 1.0],
        [99.525, 49.805, 0.938],
        [95.819, 50.287, 1.438],
        [114.708, 51.521, 1.7],
        [111.515, 51.531, 2.54],
        [113.211, 48.664, 2.39],
        [104.942, 55.744, 3.12],
        [99.827, 47.916, 2.98],
        [91.607, 49.438, 1.9],
        [107.937, 51.336, 1.75],
        [115.59, 55.081, 1.8],
    ]
)

# Compute returns and covariance.
returns = jnp.diff(prices, axis=0) / prices[:-1]
expected_returns = jnp.mean(returns, axis=0)
returns_centered = returns - expected_returns
covariance = (returns_centered.T @ returns_centered) / (returns.shape[0] - 1)
cov_chol = jnp.linalg.cholesky(covariance)

print("Expected monthly returns:")
for name, r in zip(stock_names, expected_returns):
    print(f"  {name}: {float(r) * 100:+.2f}%")
Expected monthly returns:
  IBM: +2.60%
  WMT: +0.81%
  SEHI: +7.37%
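The returns and covariance above follow the standard sample definitions. Returns are simple period-over-period returns \((p_{t+1} - p_t)/p_t\). The covariance uses a \(T - 1\) denominator, where \(T\) is the number of return observations.

As a sanity check on that reading, the sketch below compares the manual formula with jnp.cov on a small price series. The prices are made up for illustration and are not the notebook's data.

```python
import jax.numpy as jnp

# Hypothetical prices for two assets over four periods (illustration only).
prices = jnp.array(
    [
        [100.0, 50.0],
        [110.0, 49.0],
        [105.0, 52.0],
        [115.0, 51.0],
    ]
)

# Simple period-over-period returns, as in the notebook.
returns = jnp.diff(prices, axis=0) / prices[:-1]
centered = returns - jnp.mean(returns, axis=0)

# Sample covariance with the T - 1 (Bessel) denominator.
manual_cov = (centered.T @ centered) / (returns.shape[0] - 1)

# jnp.cov treats rows as variables by default, so pass rowvar=False.
library_cov = jnp.cov(returns, rowvar=False)

print(jnp.allclose(manual_cov, library_cov, atol=1e-7))  # True
```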

Problem formulation

With the simplex variable, we only need a variance cost and a return constraint. The budget and no short-selling constraints are handled implicitly by the manifold geometry.
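Written out, let \(\Sigma\) be the covariance matrix, \(\mu\) the vector of expected returns, and \(r\) the target return. The constrained formulation is

\[\min_{w \in \mathbb{R}^n} w^\top \Sigma w \quad \text{s.t.} \quad \mathbf{1}^\top w = 1,\quad w \geq 0,\quad \mu^\top w \geq r.\]

The manifold formulation restricts the search to the simplex and keeps only the return constraint:

\[\min_{w \in \Delta^{n-1}} w^\top \Sigma w \quad \text{s.t.} \quad \mu^\top w \geq r.\]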

weights_var = SimplexWeightsVar(id=0)


@jaxls.Cost.factory
def variance_cost(
    vals: jaxls.VarValues, var: SimplexWeightsVar, cov_chol: jax.Array
) -> jax.Array:
    """Minimize portfolio variance: ||L.T @ w||^2 = w.T @ cov @ w."""
    return cov_chol.T @ vals[var]


@jaxls.Cost.factory(kind="constraint_geq_zero")
def return_constraint(
    vals: jaxls.VarValues, var: SimplexWeightsVar, exp_ret: jax.Array, target: float
) -> jax.Array:
    """Expected return must meet target: E[r] >= target."""
    return jnp.array([jnp.dot(vals[var], exp_ret) - target])
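variance_cost returns the residual \(L^\top w\), and the least-squares solver minimizes the sum of squared residuals. Because \(\Sigma = L L^\top\), that sum equals the portfolio variance: \(\|L^\top w\|^2 = w^\top L L^\top w = w^\top \Sigma w\).

The sketch below checks the identity on a hypothetical 3×3 covariance matrix. The matrix is chosen diagonally dominant so that it is positive definite and jnp.linalg.cholesky succeeds.

```python
import jax.numpy as jnp

# Hypothetical 3x3 covariance matrix (symmetric positive definite; illustration only).
cov = jnp.array(
    [
        [0.0040, 0.0006, 0.0010],
        [0.0006, 0.0020, 0.0004],
        [0.0010, 0.0004, 0.0300],
    ]
)
L = jnp.linalg.cholesky(cov)  # Lower-triangular, cov = L @ L.T.
w = jnp.array([0.5, 0.3, 0.2])

residual = L.T @ w  # The residual variance_cost returns.
sum_of_squares = jnp.sum(residual**2)  # What the least-squares solver minimizes.
variance = w @ cov @ w

print(jnp.allclose(sum_of_squares, variance, rtol=1e-5))  # True
```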

Efficient frontier

We compute the efficient frontier by solving for a range of target returns. Each problem has only one explicit constraint (the return target) instead of three (budget, no short-selling, and return target).

Using jax.lax.scan, we solve these problems sequentially, with each solution serving as the initial guess for the next (warm-starting).
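jax.lax.scan threads a carry through a sequence of inputs, and that carry is what makes the warm start work. The carry holds the previous solution, and each step returns the new solution twice: once as the next carry and once as an output to stack.

The sketch below shows the same pattern on a toy problem. A hypothetical solve_sqrt "solver" stands in for problem.solve and runs a fixed number of Newton steps for \(x^2 = t\). Because each solve starts near the previous answer, four Newton steps per target are enough.

```python
import jax
import jax.numpy as jnp


def solve_sqrt(x_prev: jax.Array, target: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Hypothetical stand-in solver: four Newton steps for x**2 = target, started at x_prev."""

    def newton_step(x, _):
        return 0.5 * (x + target / x), None

    x, _ = jax.lax.scan(newton_step, x_prev, None, length=4)
    # Return the solution as the carry (next warm start) and as the stacked output.
    return x, x


targets = jnp.linspace(1.0, 4.0, 7)
x_final, solutions = jax.lax.scan(solve_sqrt, jnp.array(1.0), targets)
print(solutions)  # close to jnp.sqrt(targets)
```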

min_return = float(expected_returns.min())
max_return = float(expected_returns.max())
target_returns = jnp.linspace(min_return, max_return, 50)


def solve_for_target(
    current_vals: jaxls.VarValues, target: jax.Array
) -> tuple[jaxls.VarValues, jax.Array]:
    """Solve portfolio optimization for a given target return.

    Args:
        current_vals: Solution from previous target (used as initial guess).
        target: Target return for this solve.

    Returns:
        Tuple of (solution values, optimal weights).
    """
    costs = [
        variance_cost(weights_var, cov_chol),
        return_constraint(weights_var, expected_returns, target),
    ]
    problem = jaxls.LeastSquaresProblem(costs, [weights_var]).analyze()
    solution = problem.solve(
        current_vals,
        verbose=False,
        linear_solver="dense_cholesky",
        termination=jaxls.TerminationConfig(cost_tolerance=1e-8),
    )
    return solution, solution[weights_var]


# Solve sequentially with warm-starting.
initial_vals = jaxls.VarValues.make([weights_var])
_, all_weights = jax.lax.scan(solve_for_target, initial_vals, target_returns)
variances = jax.vmap(lambda w: w @ covariance @ w)(all_weights)
returns_achieved = jax.vmap(lambda w: jnp.dot(w, expected_returns))(all_weights)

print(f"Computed {len(target_returns)} points on the efficient frontier")
INFO     | Building optimization problem with 2 terms and 1 variables: 1 costs, 0 eq_zero, 0 leq_zero, 1 geq_zero
INFO     | Vectorizing group with 1 costs, 1 variables each: variance_cost
INFO     | Vectorizing constraint group with 1 constraints (constraint_geq_zero), 1 variables each: augmented_return_constraint
Computed 50 points on the efficient frontier
# Verify all solutions satisfy simplex constraints.
weight_sums = jnp.sum(all_weights, axis=1)
min_weights = jnp.min(all_weights, axis=1)

print(
    f"Weight sums: min={float(weight_sums.min()):.6f}, max={float(weight_sums.max()):.6f}"
)
print(f"Min weight across all solutions: {float(min_weights.min()):.6f}")
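simplex_ok = bool(jnp.allclose(weight_sums, 1.0, atol=1e-5)) and bool(jnp.all(all_weights >= 0))
print("All constraints satisfied by construction!" if simplex_ok else "Simplex constraint violated!")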
Weight sums: min=1.000000, max=1.000000
Min weight across all solutions: 0.000000
All constraints satisfied by construction!
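The minimum weight prints as 0.000000, but multiplicative updates can only shrink a weight geometrically toward zero. In exact arithmetic they never make it exactly zero.

The sketch below re-defines simplex_retract and uses an arbitrary push-down direction. It applies 20 updates, each of which shrinks the first weight's ratio to the others by a factor of e. The result rounds to 0.000000 at six decimals while remaining strictly positive. Run long enough, float32 would eventually underflow such a weight to exactly zero, so a printed 0.000000 alone does not say which case occurred.

```python
import jax.numpy as jnp


def simplex_retract(weights, delta):
    v = weights * jnp.exp(delta)
    return v / jnp.sum(v)


w = jnp.ones(3) / 3
push_down = jnp.array([-1.0, 0.0, 0.0])  # Repeatedly shrink the first weight.
for _ in range(20):
    w = simplex_retract(w, push_down)

print(f"{float(w[0]):.6f}")  # 0.000000 at six decimals...
print(float(w[0]) > 0.0)  # ...but still strictly positive: True
```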

Results


import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML
import math

# Annualize returns and risk.
annual_factor = 12
risk_free_rate = 0.05  # 5% annual risk-free rate.

std_devs_annual = [float(jnp.sqrt(v)) * math.sqrt(annual_factor) for v in variances]
returns_annual = [float(r) * annual_factor for r in returns_achieved]
sharpe_ratios = [
    (r - risk_free_rate) / s if s > 0 else 0.0
    for r, s in zip(returns_annual, std_devs_annual)
]

colors = ["#2196F3", "#4CAF50", "#FF9800"]

fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Efficient Frontier", "Sharpe Ratio", "Asset Allocation"),
    column_widths=[0.33, 0.33, 0.34],
)

# Scale to $1000 investment.
std_dev_dollars = [s * 1000 for s in std_devs_annual]
return_dollars = [r * 1000 for r in returns_annual]
target_dollars = [float(t) * annual_factor * 1000 for t in target_returns]

# Left plot: Efficient frontier.
fig.add_trace(
    go.Scatter(
        x=std_dev_dollars,
        y=return_dollars,
        mode="lines+markers",
        marker=dict(size=6, color=sharpe_ratios, colorscale="Viridis", showscale=False),
        line=dict(color="steelblue", width=2),
        hovertemplate="Std Dev: $%{x:.0f}<br>Return: $%{y:.0f}<extra></extra>",
        showlegend=False,
    ),
    row=1,
    col=1,
)

# Middle plot: Sharpe ratio vs return.
fig.add_trace(
    go.Scatter(
        x=return_dollars,
        y=sharpe_ratios,
        mode="lines+markers",
        marker=dict(size=6, color=sharpe_ratios, colorscale="Viridis", showscale=False),
        line=dict(color="steelblue", width=2),
        hovertemplate="Return: $%{x:.0f}<br>Sharpe: %{y:.2f}<extra></extra>",
        showlegend=False,
    ),
    row=1,
    col=2,
)

# Right plot: Asset allocation.
for i, (name, color) in enumerate(zip(stock_names, colors)):
    fig.add_trace(
        go.Bar(
            x=target_dollars,
            y=[float(w) * 1000 for w in all_weights[:, i]],
            name=name,
            marker_color=color,
            hovertemplate=f"{name}: $%{{y:.0f}}<extra></extra>",
        ),
        row=1,
        col=3,
    )

fig.update_xaxes(title_text="Std Dev ($)", row=1, col=1)
fig.update_yaxes(title_text="Return ($)", row=1, col=1)
fig.update_xaxes(title_text="Return ($)", row=1, col=2)
fig.update_yaxes(title_text="Sharpe Ratio", row=1, col=2)
fig.update_xaxes(title_text="Target Return ($)", row=1, col=3)
fig.update_yaxes(title_text="Investment ($)", range=[0, 1050], row=1, col=3)

fig.update_layout(
    barmode="stack",
    height=400,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(orientation="h", yanchor="bottom", y=-0.25, xanchor="center", x=0.83),
)
HTML(fig.to_html(full_html=False, include_plotlyjs="cdn"))
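For a single point, the annualization and Sharpe ratio in the plotting code reduce to a few lines of arithmetic. The sketch below uses hypothetical monthly statistics (a 2% mean return and a 5% standard deviation) together with the notebook's 5% annual risk-free rate. It applies the same simple ×12 and ×√12 scaling as the plotting code. The √12 factor for volatility assumes monthly returns are uncorrelated across months.

```python
import math

# Hypothetical monthly statistics for one portfolio (illustration only).
monthly_return = 0.02
monthly_std = 0.05

annual_factor = 12
risk_free_rate = 0.05  # 5% annual, as in the notebook.

# Simple (non-compounded) annualization, matching the plotting code above.
annual_return = monthly_return * annual_factor  # 0.24
annual_std = monthly_std * math.sqrt(annual_factor)  # ~0.1732

sharpe = (annual_return - risk_free_rate) / annual_std
print(round(sharpe, 3))  # 1.097
```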

Comparison with constrained approach

Both approaches solve the same optimization problem, so they should produce the same efficient frontier up to solver tolerance. The key difference is in how constraints are enforced:

| Approach | Budget constraint | No short-selling | Implementation |
|---|---|---|---|
| Constrained | Explicit equality constraint | Explicit inequality constraint | Augmented Lagrangian |
| Manifold | Built into retract_fn | Built into retract_fn | Multiplicative retraction |

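The manifold approach can be advantageous when:

  • The constraints define a smooth manifold with a simple retraction

  • You want fewer constraints handled by the Augmented Lagrangian solver (here one, the return target, instead of three)

One trade-off: the updates are multiplicative, so in exact arithmetic a weight can approach zero but never reach it. Allocations on the boundary of the simplex are therefore only reached asymptotically.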